author    mbhuiyan <mohammad.ashraf.bhuiyan@intel.com>    2018-05-16 10:49:29 -0700
committer mbhuiyan <mohammad.ashraf.bhuiyan@intel.com>    2018-05-16 10:49:29 -0700
commit    2acf23109aabb2952ce73dee89fe1e63b0e80961 (patch)
tree      54724426fcf6d8d9a5dab57862ae749997dc5fd5
parent    7a667f694fc25691d1093019a6fe4e0cd32fd344 (diff)
parent    383e6d48dfd5037bcb5d56937366f1ba12b9a67d (diff)
resolving the conflict while merging master
-rw-r--r--CONTRIBUTING.md11
-rw-r--r--README.md39
-rw-r--r--tensorflow/c/c_test_util.h1
-rw-r--r--tensorflow/c/eager/tape.h4
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc2
-rw-r--r--tensorflow/cc/tools/freeze_saved_model.cc20
-rw-r--r--tensorflow/cc/tools/freeze_saved_model_test.cc50
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc15
-rw-r--r--tensorflow/compiler/jit/BUILD2
-rw-r--r--tensorflow/compiler/jit/graphcycles/graphcycles.cc14
-rw-r--r--tensorflow/compiler/jit/graphcycles/graphcycles.h4
-rw-r--r--tensorflow/compiler/jit/graphcycles/graphcycles_test.cc14
-rw-r--r--tensorflow/compiler/jit/xla_device_ops.h4
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc19
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.h8
-rw-r--r--tensorflow/compiler/jit/xla_launch_util_test.cc6
-rw-r--r--tensorflow/compiler/jit/xla_tensor.cc14
-rw-r--r--tensorflow/compiler/tests/BUILD15
-rw-r--r--tensorflow/compiler/tests/listdiff_op_test.py101
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py20
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/listdiff_op.cc120
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc6
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.cc6
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.h6
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc13
-rw-r--r--tensorflow/compiler/xla/BUILD20
-rw-r--r--tensorflow/compiler/xla/client/BUILD25
-rw-r--r--tensorflow/compiler/xla/client/client.cc6
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.cc1574
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h1067
-rw-r--r--tensorflow/compiler/xla/client/global_data.cc2
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc8
-rw-r--r--tensorflow/compiler/xla/client/local_client.h8
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc10
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.h8
-rw-r--r--tensorflow/compiler/xla/error_spec.h37
-rw-r--r--tensorflow/compiler/xla/layout_util.cc22
-rw-r--r--tensorflow/compiler/xla/layout_util.h11
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc6
-rw-r--r--tensorflow/compiler/xla/literal_comparison.cc739
-rw-r--r--tensorflow/compiler/xla/literal_comparison.h72
-rw-r--r--tensorflow/compiler/xla/literal_util.cc935
-rw-r--r--tensorflow/compiler/xla/literal_util.h1299
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc47
-rw-r--r--tensorflow/compiler/xla/map_util.h16
-rw-r--r--tensorflow/compiler/xla/python/numpy_bridge.cc8
-rw-r--r--tensorflow/compiler/xla/python/numpy_bridge.h7
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_client_test.cc4
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_service.cc4
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_stub.cc116
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_stub.h121
-rw-r--r--tensorflow/compiler/xla/service/BUILD92
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc50
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.cc15
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.h10
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification.cc99
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification.h39
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification_test.cc168
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.cc49
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc16
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.cc4
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.h2
-rw-r--r--tensorflow/compiler/xla/service/buffer_value_containers.h55
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.h40
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD33
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization.h8
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc13
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc39
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.h4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc94
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc142
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.h14
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h9
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc15
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc54
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.h12
-rw-r--r--tensorflow/compiler/xla/service/cpu/external_constant_pool.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/external_constant_pool.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc32
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emission_utils.h9
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc48
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h7
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc13
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h13
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc30
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc25
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/target_machine_features.cc27
-rw-r--r--tensorflow/compiler/xla/service/cpu/target_machine_features.h55
-rw-r--r--tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h57
-rw-r--r--tensorflow/compiler/xla/service/device_memory_allocator.cc19
-rw-r--r--tensorflow/compiler/xla/service/device_memory_allocator.h28
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h6
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc97
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc65
-rw-r--r--tensorflow/compiler/xla/service/execution_tracker.cc8
-rw-r--r--tensorflow/compiler/xla/service/execution_tracker.h4
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/gpu/buffer_allocations.cc64
-rw-r--r--tensorflow/compiler/xla/service/gpu/buffer_allocations.h17
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_thunk.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_thunk.h8
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc40
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/fft_thunk.cc37
-rw-r--r--tensorflow/compiler/xla/service/gpu/fft_thunk.h8
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.cc13
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.cc28
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc41
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc84
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc46
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc36
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc40
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.cc57
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.h10
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc13
-rw-r--r--tensorflow/compiler/xla/service/gpu/sequential_thunk.cc12
-rw-r--r--tensorflow/compiler/xla/service/gpu/sequential_thunk.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/thunk.h29
-rw-r--r--tensorflow/compiler/xla/service/gpu/tuple_thunk.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/tuple_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer.cc12
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc81
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.h55
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc66
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc57
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation_test.cc102
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc19
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.cc68
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc155
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h46
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile_test.cc48
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc117
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h125
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc158
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc40
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc64
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h10
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.cc100
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.h7
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc103
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc38
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h4
-rw-r--r--tensorflow/compiler/xla/service/human_readable_profile_builder.cc80
-rw-r--r--tensorflow/compiler/xla/service/human_readable_profile_builder.h9
-rw-r--r--tensorflow/compiler/xla/service/inliner_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc2
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion_test.cc16
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc13
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc9
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h5
-rw-r--r--tensorflow/compiler/xla/service/owning_device_memory.cc35
-rw-r--r--tensorflow/compiler/xla/service/owning_device_memory.h131
-rw-r--r--tensorflow/compiler/xla/service/service.cc200
-rw-r--r--tensorflow/compiler/xla/service/service.h133
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc36
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.cc10
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.h24
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h2
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.cc2
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc4
-rw-r--r--tensorflow/compiler/xla/service_interface.h114
-rw-r--r--tensorflow/compiler/xla/shape_layout.cc8
-rw-r--r--tensorflow/compiler/xla/shape_layout.h4
-rw-r--r--tensorflow/compiler/xla/shape_util.h2
-rw-r--r--tensorflow/compiler/xla/status.h2
-rw-r--r--tensorflow/compiler/xla/statusor_test.cc2
-rw-r--r--tensorflow/compiler/xla/test_helpers.h29
-rw-r--r--tensorflow/compiler/xla/tests/BUILD71
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_test.cc56
-rw-r--r--tensorflow/compiler/xla/tests/call_test.cc1
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc52
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h16
-rw-r--r--tensorflow/compiler/xla/tests/client_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/compilation_cache_test.cc9
-rw-r--r--tensorflow/compiler/xla/tests/compute_constant_test.cc10
-rw-r--r--tensorflow/compiler/xla/tests/constants_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/convolution_variants_test.cc1
-rw-r--r--tensorflow/compiler/xla/tests/copy_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/deallocation_test.cc1
-rw-r--r--tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc1
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc18
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc114
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc1001
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h257
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util_test.cc11
-rw-r--r--tensorflow/compiler/xla/tests/local_client_execute_test.cc44
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.cc9
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.h7
-rw-r--r--tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc1
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/params_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc10
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc1
-rw-r--r--tensorflow/compiler/xla/tests/reshape_test.cc20
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_transfer_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/transfer_manager_test.cc10
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc4
-rw-r--r--tensorflow/compiler/xla/text_literal_writer.cc4
-rw-r--r--tensorflow/compiler/xla/text_literal_writer.h4
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc2
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc20
-rw-r--r--tensorflow/compiler/xla/xla_data.proto6
-rw-r--r--tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/__init__.py1
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements.py9
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow.py125
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow_test.py86
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/cfg.py18
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py41
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc9
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py61
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py11
-rw-r--r--tensorflow/contrib/checkpoint/python/BUILD23
-rw-r--r--tensorflow/contrib/checkpoint/python/containers.py77
-rw-r--r--tensorflow/contrib/checkpoint/python/containers_test.py100
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py14
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt3
-rw-r--r--tensorflow/contrib/cmake/tf_core_kernels.cmake1
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake10
-rw-r--r--tensorflow/contrib/data/__init__.py4
-rw-r--r--tensorflow/contrib/data/kernels/BUILD11
-rw-r--r--tensorflow/contrib/data/kernels/csv_dataset_op.cc508
-rw-r--r--tensorflow/contrib/data/ops/dataset_ops.cc34
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD19
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py378
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py28
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD21
-rw-r--r--tensorflow/contrib/data/python/ops/iterator_ops.py169
-rw-r--r--tensorflow/contrib/data/python/ops/iterator_ops_test.py123
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py143
-rw-r--r--tensorflow/contrib/distributions/BUILD5
-rw-r--r--tensorflow/contrib/distributions/python/ops/autoregressive.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/batch_reshape.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/binomial.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/cauchy.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/chi2.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/deterministic.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/geometric.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/gumbel.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/half_normal.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/independent.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/inverse_gamma.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/logistic.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture_same_family.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_tril.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/negative_binomial.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/onehot_categorical.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/poisson.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/poisson_lognormal.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_student_t.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/wishart.py6
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/BUILD25
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py57
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/scan_test.py56
-rw-r--r--tensorflow/contrib/eager/python/network.py6
-rw-r--r--tensorflow/contrib/estimator/BUILD46
-rw-r--r--tensorflow/contrib/estimator/__init__.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/baseline.py98
-rw-r--r--tensorflow/contrib/estimator/python/estimator/baseline_test.py430
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/export.py77
-rw-r--r--tensorflow/contrib/estimator/python/estimator/export_test.py42
-rw-r--r--tensorflow/contrib/estimator/python/estimator/extenders.py6
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py230
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py90
-rw-r--r--tensorflow/contrib/estimator/python/estimator/logit_fns.py4
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py4
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl_test.py2
-rw-r--r--tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc2
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py2
-rw-r--r--tensorflow/contrib/learn/BUILD3
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py4
-rw-r--r--tensorflow/contrib/lite/BUILD16
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h2
-rw-r--r--tensorflow/contrib/lite/context.h12
-rw-r--r--tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm2
-rw-r--r--tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm2
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h4
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.cc50
-rw-r--r--tensorflow/contrib/lite/g3doc/custom_operators.md4
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md18
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml1
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml47
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml46
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml3
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml7
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD32
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc40
-rw-r--r--tensorflow/contrib/lite/kernels/add.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/arg_max.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/audio_spectrogram.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn.cc168
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn_test.cc155
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd.cc28
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc65
-rw-r--r--tensorflow/contrib/lite/kernels/cast.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc20
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc20
-rw-r--r--tensorflow/contrib/lite/kernels/dequantize.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/div.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise.cc67
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise_test.cc60
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc20
-rw-r--r--tensorflow/contrib/lite/kernels/exp.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/floor.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/gather.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/hashtable_lookup.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc74
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.h17
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc20
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h188
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h67
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor.h28
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.cc15
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.h19
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/local_response_norm.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/lsh_projection.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc40
-rw-r--r--tensorflow/contrib/lite/kernels/maximum_minimum.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/mean.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/mfcc.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/mul.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/neg.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/pad.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/pooling.cc22
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/register.h17
-rw-r--r--tensorflow/contrib/lite/kernels/reshape.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/select.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/slice.cc199
-rw-r--r--tensorflow/contrib/lite/kernels/slice_test.cc173
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_batch_nd.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_depth.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/split.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/squeeze.cc11
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/sub.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/svdf.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h16
-rw-r--r--tensorflow/contrib/lite/kernels/topk_v2.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/transpose.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc40
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc187
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc243
-rw-r--r--tensorflow/contrib/lite/model.cc33
-rw-r--r--tensorflow/contrib/lite/model.h15
-rw-r--r--tensorflow/contrib/lite/model_test.cc5
-rw-r--r--tensorflow/contrib/lite/models/smartreply/BUILD2
-rw-r--r--tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc4
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor.cc2
-rw-r--r--tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h981
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc65
-rw-r--r--tensorflow/contrib/lite/op_resolver.cc57
-rw-r--r--tensorflow/contrib/lite/op_resolver.h94
-rw-r--r--tensorflow/contrib/lite/op_resolver_test.cc129
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs10
-rwxr-xr-x[-rw-r--r--]tensorflow/contrib/lite/schema/schema_generated.h156
-rw-r--r--tensorflow/contrib/lite/testing/BUILD2
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py104
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc37
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/format_port.h4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc30
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc3
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc3
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc21
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc165
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc59
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc31
-rw-r--r--tensorflow/contrib/lite/toco/model.h16
-rw-r--r--tensorflow/contrib/lite/toco/python/toco.i7
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_python_api.cc12
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_python_api.h7
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc3
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types.cc18
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types_test.cc13
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc17
-rw-r--r--tensorflow/contrib/lite/tools/benchmark_model.cc2
-rw-r--r--tensorflow/contrib/lite/tools/gen_op_registration_main.cc2
-rw-r--r--tensorflow/contrib/lite/tools/mutable_op_resolver.cc28
-rw-r--r--tensorflow/contrib/lite/tools/mutable_op_resolver.h39
-rw-r--r--tensorflow/contrib/lite/tools/verifier.cc13
-rw-r--r--tensorflow/contrib/lite/tools/verifier.h5
-rw-r--r--tensorflow/contrib/lite/tools/verifier_test.cc1
-rw-r--r--tensorflow/contrib/metrics/BUILD2
-rw-r--r--tensorflow/contrib/mixed_precision/BUILD32
-rw-r--r--tensorflow/contrib/mixed_precision/__init__.py34
-rw-r--r--tensorflow/contrib/mixed_precision/python/BUILD74
-rw-r--r--tensorflow/contrib/mixed_precision/python/loss_scale_manager.py200
-rw-r--r--tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py182
-rw-r--r--tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py166
-rw-r--r--tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py216
-rw-r--r--tensorflow/contrib/mpi/mpi_utils.h1
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py88
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py11
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py14
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms_test.py6
-rw-r--r--tensorflow/contrib/quantize/python/graph_matcher.py35
-rw-r--r--tensorflow/contrib/quantize/python/graph_matcher_test.py39
-rw-r--r--tensorflow/contrib/quantize/python/quant_ops.py6
-rw-r--r--tensorflow/contrib/quantize/python/quant_ops_test.py32
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py59
-rw-r--r--tensorflow/contrib/receptive_field/python/util/receptive_field_test.py2
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py14
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py140
-rw-r--r--tensorflow/contrib/signal/python/ops/window_ops.py4
-rw-r--r--tensorflow/contrib/slim/README.md5
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py2
-rw-r--r--tensorflow/contrib/sparsemax/BUILD2
-rw-r--r--tensorflow/contrib/summary/summary.py5
-rw-r--r--tensorflow/contrib/tensorrt/BUILD44
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc4
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc94
-rw-r--r--tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD118
-rw-r--r--tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py (renamed from tensorflow/python/keras/estimator/__init__.py)13
-rw-r--r--tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py32
-rw-r--r--tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc84
-rw-r--r--tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h35
-rw-r--r--tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc86
-rw-r--r--tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h102
-rw-r--r--tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc36
-rw-r--r--tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py95
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc4
-rw-r--r--tensorflow/contrib/tensorrt/log/trt_logger.h2
-rw-r--r--tensorflow/contrib/tensorrt/plugin/trt_plugin.cc106
-rw-r--r--tensorflow/contrib/tensorrt/plugin/trt_plugin.h74
-rw-r--r--tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc78
-rw-r--r--tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h102
-rw-r--r--tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc125
-rw-r--r--tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc42
-rw-r--r--tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h46
-rw-r--r--tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc4
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/BUILD1
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py9
-rw-r--r--tensorflow/contrib/tpu/profiler/op_profile.proto5
-rw-r--r--tensorflow/contrib/tpu/python/tpu/session_support.py11
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py20
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py36
-rw-r--r--tensorflow/contrib/training/python/training/hparam.py9
-rw-r--r--tensorflow/contrib/training/python/training/hparam_test.py15
-rw-r--r--tensorflow/core/BUILD41
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CropAndResize.pbtxt27
-rw-r--r--tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV3.pbtxt64
-rw-r--r--tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt30
-rw-r--r--tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV3.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_RegexFullMatch.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/broadcaster.cc24
-rw-r--r--tensorflow/core/common_runtime/broadcaster_test.cc102
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.cc100
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.h9
-rw-r--r--tensorflow/core/common_runtime/collective_rma_local.cc23
-rw-r--r--tensorflow/core/common_runtime/device.h11
-rw-r--r--tensorflow/core/common_runtime/device_mgr.cc3
-rw-r--r--tensorflow/core/common_runtime/function.cc66
-rw-r--r--tensorflow/core/common_runtime/function_test.cc27
-rw-r--r--tensorflow/core/common_runtime/function_threadpool_test.cc14
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc26
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_test.cc2
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc10
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager.cc52
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager.h5
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc37
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_utils.h7
-rw-r--r--tensorflow/core/common_runtime/gpu/process_state.cc3
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.cc283
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.h38
-rw-r--r--tensorflow/core/common_runtime/lower_if_op_test.cc140
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc20
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h12
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime_test.cc14
-rw-r--r--tensorflow/core/common_runtime/ring_reducer.cc24
-rw-r--r--tensorflow/core/debug/BUILD1
-rw-r--r--tensorflow/core/distributed_runtime/BUILD34
-rw-r--r--tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc1
-rw-r--r--tensorflow/core/distributed_runtime/collective_rma_distributed.cc209
-rw-r--r--tensorflow/core/distributed_runtime/collective_rma_distributed.h50
-rw-r--r--tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc356
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD1
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc7
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc100
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h3
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h1
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr.cc2
-rw-r--r--tensorflow/core/distributed_runtime/test_utils.h5
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc9
-rw-r--r--tensorflow/core/distributed_runtime/worker.h3
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache_partial.cc2
-rw-r--r--tensorflow/core/distributed_runtime/worker_interface.h3
-rw-r--r--tensorflow/core/framework/attr_value_util.cc236
-rw-r--r--tensorflow/core/framework/attr_value_util.h13
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc15
-rw-r--r--tensorflow/core/framework/common_shape_fns.h3
-rw-r--r--tensorflow/core/framework/shape_inference.cc6
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc27
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc54
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc33
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc57
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h2
-rw-r--r--tensorflow/core/grappler/op_types.cc15
-rw-r--r--tensorflow/core/grappler/op_types.h4
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD40
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc108
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc242
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc1446
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h16
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc159
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer.cc234
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.cc6
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc7
-rw-r--r--tensorflow/core/grappler/optimizers/shape_optimizer.cc133
-rw-r--r--tensorflow/core/grappler/optimizers/shape_optimizer.h54
-rw-r--r--tensorflow/core/grappler/optimizers/shape_optimizer_test.cc105
-rw-r--r--tensorflow/core/grappler/optimizers/symbolic_shapes.cc60
-rw-r--r--tensorflow/core/grappler/optimizers/symbolic_shapes.h14
-rw-r--r--tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc27
-rw-r--r--tensorflow/core/grappler/utils/functions.cc10
-rw-r--r--tensorflow/core/grappler/utils/functions.h3
-rw-r--r--tensorflow/core/kernels/BUILD10
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_impl.h106
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_real.cc4
-rw-r--r--tensorflow/core/kernels/conv_grad_filter_ops.cc3
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc5
-rw-r--r--tensorflow/core/kernels/conv_grad_ops_3d.cc4
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op.cc151
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op.h5
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc183
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op_test.cc166
-rw-r--r--tensorflow/core/kernels/cudnn_rnn_ops.cc17
-rw-r--r--tensorflow/core/kernels/data/sql_dataset_ops.cc89
-rw-r--r--tensorflow/core/kernels/deep_conv2d.cc10
-rw-r--r--tensorflow/core/kernels/depthwise_conv_grad_op.cc2
-rw-r--r--tensorflow/core/kernels/dequantize_op.cc26
-rw-r--r--tensorflow/core/kernels/dequantize_op_test.cc6
-rw-r--r--tensorflow/core/kernels/functional_ops.cc63
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc139
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.h3
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op_test.cc191
-rw-r--r--tensorflow/core/kernels/ops_testutil.cc2
-rw-r--r--tensorflow/core/kernels/quantize_op.cc39
-rw-r--r--tensorflow/core/kernels/quantize_op_test.cc44
-rw-r--r--tensorflow/core/kernels/regex_full_match_op.cc59
-rw-r--r--tensorflow/core/kernels/roll_op.cc2
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc47
-rw-r--r--tensorflow/core/kernels/scoped_allocator_ops.cc2
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.cc4
-rw-r--r--tensorflow/core/lib/gtl/flatmap.h11
-rw-r--r--tensorflow/core/lib/gtl/flatmap_test.cc26
-rw-r--r--tensorflow/core/lib/gtl/flatrep.h21
-rw-r--r--tensorflow/core/lib/gtl/flatset.h17
-rw-r--r--tensorflow/core/lib/gtl/flatset_test.cc26
-rw-r--r--tensorflow/core/lib/hash/hash.h6
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt134
-rw-r--r--tensorflow/core/ops/image_ops.cc54
-rw-r--r--tensorflow/core/ops/image_ops_test.cc19
-rw-r--r--tensorflow/core/ops/math_ops.cc2
-rw-r--r--tensorflow/core/ops/nn_ops.cc3
-rw-r--r--tensorflow/core/ops/ops.pbtxt29
-rw-r--r--tensorflow/core/ops/random_ops.cc10
-rw-r--r--tensorflow/core/ops/rpc_ops.cc1
-rw-r--r--tensorflow/core/ops/string_ops.cc11
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc4
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.h10
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system_test.cc98
-rw-r--r--tensorflow/core/platform/cloud/oauth_client.cc4
-rw-r--r--tensorflow/core/platform/error.h30
-rw-r--r--tensorflow/core/platform/hadoop/hadoop_file_system.cc2
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto3
-rw-r--r--tensorflow/core/protobuf/transport_options.proto8
-rw-r--r--tensorflow/core/protobuf/worker.proto54
-rw-r--r--tensorflow/core/protobuf/worker_service.proto4
-rw-r--r--tensorflow/docs_src/community/swift.md8
-rw-r--r--tensorflow/docs_src/extend/adding_an_op.md63
-rw-r--r--tensorflow/docs_src/mobile/tflite/index.md2
-rw-r--r--tensorflow/docs_src/performance/xla/broadcasting.md4
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md655
-rw-r--r--tensorflow/docs_src/programmers_guide/eager.md6
-rw-r--r--tensorflow/docs_src/tutorials/audio_recognition.md16
-rw-r--r--tensorflow/examples/learn/text_classification_cnn.py2
-rw-r--r--tensorflow/examples/speech_commands/train.py2
-rw-r--r--tensorflow/go/op/wrappers.go128
-rw-r--r--tensorflow/python/BUILD49
-rw-r--r--tensorflow/python/client/virtual_gpu_test.py2
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py59
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py30
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py7
-rw-r--r--tensorflow/python/eager/BUILD17
-rw-r--r--tensorflow/python/eager/backprop_test.py12
-rw-r--r--tensorflow/python/eager/python_eager_op_gen.cc1047
-rw-r--r--tensorflow/python/eager/python_eager_op_gen.h43
-rw-r--r--tensorflow/python/eager/pywrap_tensor.h1
-rw-r--r--tensorflow/python/estimator/BUILD75
-rw-r--r--tensorflow/python/estimator/canned/dnn.py78
-rw-r--r--tensorflow/python/estimator/canned/head.py26
-rw-r--r--tensorflow/python/estimator/canned/metric_keys.py5
-rw-r--r--tensorflow/python/estimator/estimator.py192
-rw-r--r--tensorflow/python/estimator/estimator_lib.py1
-rw-r--r--tensorflow/python/estimator/estimator_test.py187
-rw-r--r--tensorflow/python/estimator/export/export.py99
-rw-r--r--tensorflow/python/estimator/keras.py (renamed from tensorflow/python/keras/_impl/keras/estimator.py)0
-rw-r--r--tensorflow/python/estimator/keras_test.py (renamed from tensorflow/python/keras/_impl/keras/estimator_test.py)37
-rw-r--r--tensorflow/python/estimator/run_config.py4
-rw-r--r--tensorflow/python/estimator/util.py40
-rw-r--r--tensorflow/python/feature_column/feature_column.py12
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py22
-rw-r--r--tensorflow/python/framework/c_api_util.py46
-rw-r--r--tensorflow/python/framework/c_api_util_test.py55
-rw-r--r--tensorflow/python/framework/fast_tensor_util.pyx12
-rw-r--r--tensorflow/python/framework/function.py19
-rw-r--r--tensorflow/python/framework/function_test.py52
-rw-r--r--tensorflow/python/framework/load_library.py2
-rw-r--r--tensorflow/python/framework/ops.py30
-rw-r--r--tensorflow/python/framework/ops_test.py9
-rw-r--r--tensorflow/python/framework/python_op_gen.cc1427
-rw-r--r--tensorflow/python/framework/python_op_gen.h19
-rw-r--r--tensorflow/python/framework/python_op_gen.i8
-rw-r--r--tensorflow/python/framework/python_op_gen_internal.cc800
-rw-r--r--tensorflow/python/framework/python_op_gen_main.cc9
-rw-r--r--tensorflow/python/framework/tensor_util.py12
-rwxr-xr-xtensorflow/python/keras/BUILD26
-rw-r--r--tensorflow/python/keras/__init__.py1
-rw-r--r--tensorflow/python/keras/_impl/keras/__init__.py3
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/mobilenet.py13
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/nasnet.py17
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/vgg16.py10
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/vgg19.py10
-rw-r--r--tensorflow/python/keras/_impl/keras/backend.py23
-rw-r--r--tensorflow/python/keras/_impl/keras/backend_test.py24
-rw-r--r--tensorflow/python/keras/_impl/keras/callbacks.py45
-rw-r--r--tensorflow/python/keras/_impl/keras/callbacks_test.py149
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/base_layer.py12
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/network.py21
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/saving.py184
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/saving_test.py55
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/topology_test.py27
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training.py15
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_arrays.py11
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_eager.py10
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_eager_test.py74
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_generator.py32
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_test.py1
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional.py14
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py37
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py17
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/core.py27
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/core_test.py10
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/cudnn_recurrent_test.py39
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/normalization_test.py70
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/recurrent.py108
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/wrappers.py99
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/wrappers_test.py135
-rw-r--r--tensorflow/python/keras/_impl/keras/metrics_test.py43
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/image.py305
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/image_test.py32
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/sequence.py15
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py67
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/text.py58
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/text_test.py10
-rw-r--r--tensorflow/python/keras/_impl/keras/testing_utils.py3
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/generic_utils.py5
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils.py53
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/np_utils.py2
-rw-r--r--tensorflow/python/keras/utils/__init__.py1
-rw-r--r--tensorflow/python/kernel_tests/BUILD16
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/control_flow_util_test.py109
-rw-r--r--tensorflow/python/kernel_tests/conv3d_transpose_test.py17
-rw-r--r--tensorflow/python/kernel_tests/distributions/BUILD1
-rw-r--r--tensorflow/python/kernel_tests/distributions/special_math_test.py160
-rw-r--r--tensorflow/python/kernel_tests/distributions/util_test.py56
-rw-r--r--tensorflow/python/kernel_tests/functional_ops_test.py194
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py4
-rw-r--r--tensorflow/python/kernel_tests/regex_full_match_op_test.py54
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py12
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py10
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py7
-rw-r--r--tensorflow/python/layers/base.py4
-rw-r--r--tensorflow/python/lib/core/ndarray_tensor.cc2
-rw-r--r--tensorflow/python/lib/core/ndarray_tensor_bridge.cc45
-rw-r--r--tensorflow/python/ops/control_flow_ops.py6
-rw-r--r--tensorflow/python/ops/control_flow_util.py53
-rw-r--r--tensorflow/python/ops/distributions/bernoulli.py2
-rw-r--r--tensorflow/python/ops/distributions/beta.py4
-rw-r--r--tensorflow/python/ops/distributions/categorical.py2
-rw-r--r--tensorflow/python/ops/distributions/dirichlet.py2
-rw-r--r--tensorflow/python/ops/distributions/dirichlet_multinomial.py2
-rw-r--r--tensorflow/python/ops/distributions/distribution.py3
-rw-r--r--tensorflow/python/ops/distributions/exponential.py5
-rw-r--r--tensorflow/python/ops/distributions/gamma.py4
-rw-r--r--tensorflow/python/ops/distributions/laplace.py5
-rw-r--r--tensorflow/python/ops/distributions/multinomial.py2
-rw-r--r--tensorflow/python/ops/distributions/normal.py5
-rw-r--r--tensorflow/python/ops/distributions/special_math.py45
-rw-r--r--tensorflow/python/ops/distributions/student_t.py4
-rw-r--r--tensorflow/python/ops/distributions/transformed_distribution.py2
-rw-r--r--tensorflow/python/ops/distributions/uniform.py3
-rw-r--r--tensorflow/python/ops/distributions/util.py38
-rw-r--r--tensorflow/python/ops/functional_ops.py4
-rw-r--r--tensorflow/python/ops/image_grad.py18
-rw-r--r--tensorflow/python/ops/image_ops_impl.py9
-rw-r--r--tensorflow/python/ops/init_ops.py6
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_kronecker.py4
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_test_util.py4
-rw-r--r--tensorflow/python/ops/nn_ops.py15
-rw-r--r--tensorflow/python/ops/rnn.py7
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py8
-rw-r--r--tensorflow/python/ops/string_ops.py2
-rw-r--r--tensorflow/python/ops/template.py36
-rw-r--r--tensorflow/python/ops/variable_scope.py7
-rw-r--r--tensorflow/python/training/checkpointable.py68
-rw-r--r--tensorflow/python/training/checkpointable_utils.py141
-rw-r--r--tensorflow/python/training/checkpointable_utils_test.py141
-rw-r--r--tensorflow/python/training/monitored_session.py4
-rw-r--r--tensorflow/python/training/optimizer.py11
-rw-r--r--tensorflow/python/training/saver.py132
-rw-r--r--tensorflow/python/util/function_utils.py57
-rw-r--r--tensorflow/python/util/function_utils_test.py (renamed from tensorflow/python/estimator/util_test.py)18
-rw-r--r--tensorflow/python/util/serialization.py64
-rw-r--r--tensorflow/python/util/serialization_test.py76
-rw-r--r--tensorflow/python/util/tf_inspect.py2
-rw-r--r--tensorflow/stream_executor/blas.h14
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc108
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.h6
-rw-r--r--tensorflow/stream_executor/cuda/cuda_diagnostics.cc60
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc250
-rw-r--r--tensorflow/stream_executor/cuda/cuda_driver.cc152
-rw-r--r--tensorflow/stream_executor/cuda/cuda_fft.cc60
-rw-r--r--tensorflow/stream_executor/cuda/cuda_gpu_executor.cc4
-rw-r--r--tensorflow/stream_executor/cuda/cuda_platform.cc8
-rw-r--r--tensorflow/stream_executor/cuda/cuda_rng.cc8
-rw-r--r--tensorflow/stream_executor/dnn.h31
-rw-r--r--tensorflow/stream_executor/host/host_gpu_executor.h10
-rw-r--r--tensorflow/stream_executor/host/host_platform.cc4
-rw-r--r--tensorflow/stream_executor/host_or_device_scalar.h2
-rw-r--r--tensorflow/stream_executor/kernel_spec.cc4
-rw-r--r--tensorflow/stream_executor/plugin_registry.cc21
-rw-r--r--tensorflow/stream_executor/stream.cc42
-rw-r--r--tensorflow/stream_executor/stream.h21
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc32
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h3
-rw-r--r--tensorflow/tensorflow.bzl15
-rw-r--r--tensorflow/tools/api/generator/BUILD1
-rw-r--r--tensorflow/tools/api/generator/create_python_api.py35
-rw-r--r--tensorflow/tools/api/generator/create_python_api_test.py9
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.strings.pbtxt7
-rwxr-xr-xtensorflow/tools/ci_build/install/install_pip_packages_remote.sh4
-rw-r--r--tensorflow/tools/graph_transforms/transform_graph.cc2
-rwxr-xr-xtensorflow/tools/pip_package/build_pip_package.sh11
-rw-r--r--tensorflow/tools/pip_package/setup.py22
-rw-r--r--tensorflow/workspace.bzl8
-rw-r--r--third_party/gpus/cuda_configure.bzl2
-rw-r--r--third_party/libxsmm.BUILD2
-rw-r--r--tools/bazel.rc6
850 files changed, 26755 insertions, 14620 deletions
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 3dad41a88c..8669c25c45 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,5 +1,16 @@
# Contributing guidelines
+## Pull Request Checklist
+
+Before sending your pull requests, make sure you have followed this checklist.
+
+- Read the [contributing guidelines](CONTRIBUTING.md).
+- Read the [Code of Conduct](CODE_OF_CONDUCT.md).
+- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/).
+- Check that your changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution).
+- Ensure your changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style).
+- Run the [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests).
+
## How to become a contributor and submit your own code
### Contributor License Agreements
diff --git a/README.md b/README.md
index e1a50c87e2..6fb4486d0d 100644
--- a/README.md
+++ b/README.md
@@ -5,9 +5,9 @@
-----------------
-| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** |
-|-----------------|---------------------|------------------|-------------------|---------------|---------------|
-| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
+| **`Documentation`** |
+|-----------------|
+| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) |
**TensorFlow** is an open source software library for numerical computation using
data flow graphs. The graph nodes represent mathematical operations, while
@@ -40,15 +40,6 @@ environment to install the nightly TensorFlow build. We support CPU and GPU
packages on Linux, Mac, and Windows.
-**Individual whl files**
-* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/))
-* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/))
-* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/))
-* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/))
-* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
-([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
-
#### *Try your first TensorFlow program*
```shell
$ python
@@ -82,6 +73,30 @@ The TensorFlow project strives to abide by generally accepted best practices in
[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486)
+
+## Continuous build status
+
+### Official Builds
+
+| Build Type | Status | Artifacts |
+| --- | --- | --- |
+| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) |
+| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
+| **Linux XLA** | TBA | TBA |
+| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) |
+| **Windows CPU** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [pypi](https://pypi.org/project/tf-nightly/) |
+| **Windows GPU** | [![Status](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/badge/icon)](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) |
+| **Android** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) |
+
+
+### Community Supported Builds
+
+| Build Type | Status | Artifacts |
+| --- | --- | --- |
+| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA |
+| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA |
+
+
## For more information
* [TensorFlow Website](https://www.tensorflow.org)
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index cd19cf8d62..c16aba666e 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index e9ed3395c4..dcc2357b71 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -195,7 +195,9 @@ bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
CHECK_EQ(tensor_ids.size(), dtypes.size());
for (int i = 0; i < tensor_ids.size(); ++i) {
if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
- return IsDtypeTrainable(dtypes[i]);
+ if (IsDtypeTrainable(dtypes[i])) {
+ return true;
+ }
}
}
return false;
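
The hunk above changes GradientTape::ShouldRecord so that a watched but non-trainable tensor no longer short-circuits the scan: the tape now records as soon as any watched, trainable tensor is found. A minimal before/after sketch using plain standard-library stand-ins (the watched set and trainable flags below are illustrative, not the real tape types):

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical stand-ins for the tape's state: the set of watched tensor ids
// and a per-input "trainable dtype" flag, indexed like tensor_ids.
bool ShouldRecordOld(const std::unordered_set<int>& watched,
                     const std::vector<int>& tensor_ids,
                     const std::vector<bool>& trainable) {
  for (std::size_t i = 0; i < tensor_ids.size(); ++i) {
    if (watched.count(tensor_ids[i])) {
      return trainable[i];  // Old behaviour: decided by the first watched tensor.
    }
  }
  return false;
}

bool ShouldRecordNew(const std::unordered_set<int>& watched,
                     const std::vector<int>& tensor_ids,
                     const std::vector<bool>& trainable) {
  for (std::size_t i = 0; i < tensor_ids.size(); ++i) {
    if (watched.count(tensor_ids[i]) && trainable[i]) {
      return true;  // Keep scanning until a watched, trainable tensor is found.
    }
  }
  return false;
}

int main() {
  // Input 1 is watched but not trainable (say, an integer tensor); input 3 is
  // watched and trainable. The old version wrongly answered "don't record".
  std::unordered_set<int> watched = {1, 3};
  std::vector<int> ids = {1, 2, 3};
  std::vector<bool> trainable = {false, true, true};
  assert(!ShouldRecordOld(watched, ids, trainable));
  assert(ShouldRecordNew(watched, ids, trainable));
  return 0;
}
```
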
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 1b4c7c2688..fd7b6fe662 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -31,7 +31,6 @@ using ops::AddN;
using ops::BatchMatMul;
using ops::Const;
using ops::Div;
-using ops::Greater;
using ops::MatMul;
using ops::Max;
using ops::Maximum;
@@ -46,7 +45,6 @@ using ops::RealDiv;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
-using ops::Where3;
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc
index 4ddddcb586..23e9dc40d2 100644
--- a/tensorflow/cc/tools/freeze_saved_model.cc
+++ b/tensorflow/cc/tools/freeze_saved_model.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/cc/tools/freeze_saved_model.h"
+#include <iostream>
#include <queue>
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -71,6 +72,15 @@ void GetNodeNameToNodeDefMap(
}
}
+// Strips off the tensor part of the tensor_name to get the node_name.
+const string GetNodeNameFromTensorName(string tensor_name) {
+ if (tensor_name[0] == '^') {
+ tensor_name.erase(0, 1);
+ }
+ std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
+ return tensor_name_parts[0];
+}
+
// Gets the set of node names needed by `outputs` and the corresponding set of
// variable nodes to convert.
void GetReachableNodesAndVariables(
@@ -83,10 +93,8 @@ void GetReachableNodesAndVariables(
new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"});
std::queue<string> nodes_to_visit;
- for (const string& tensor_name : outputs) {
- // We need to strip off the tensor part to get the node name.
- std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
- nodes_to_visit.push(tensor_name_parts[0]);
+ for (const string& output_tensor_name : outputs) {
+ nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name));
}
// We do a traversal backwards from the outputs specified in the MetaGraphDef.
while (!nodes_to_visit.empty()) {
@@ -100,8 +108,8 @@ void GetReachableNodesAndVariables(
if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
variable_node_names->insert(node->name());
}
- for (const string& input : node->input()) {
- nodes_to_visit.push(input);
+ for (const string& input_tensor_name : node->input()) {
+ nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name));
}
}
}
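
The helper added above normalizes every name before it is pushed onto the traversal queue, so output-indexed inputs such as split:1 and control-dependency inputs such as ^source both resolve to their producing node. A self-contained sketch of that stripping rule (plain std::string in place of TensorFlow's str_util::Split, for illustration only):

```cpp
#include <cassert>
#include <string>

// Drop a leading '^' (control dependency) and anything from the first ':'
// onwards (output index) to recover the node name.
std::string NodeNameFromTensorName(std::string tensor_name) {
  if (!tensor_name.empty() && tensor_name[0] == '^') {
    tensor_name.erase(0, 1);
  }
  return tensor_name.substr(0, tensor_name.find(':'));
}

int main() {
  assert(NodeNameFromTensorName("split:1") == "split");
  assert(NodeNameFromTensorName("^source") == "source");
  assert(NodeNameFromTensorName("c") == "c");
  return 0;
}
```
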
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
index cd35fd3b95..979b23c3fc 100644
--- a/tensorflow/cc/tools/freeze_saved_model_test.cc
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -351,6 +351,56 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) {
GraphDefEqual(frozen_graph_def, graph_def);
}
+TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) {
+ // Tensors from operations with multiple outputs get output-index suffixes
+ // when used as inputs of subsequent nodes, e.g. split:0, split:1.
+ // Test that we traverse those correctly.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output a = ops::Const(scope.WithOpName("a"), {10.0f, 10.0f}, {2});
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+ OutputList split = ops::Split(scope.WithOpName("split"), axis, a, 2).output;
+ Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+ Output c = ops::Mul(scope.WithOpName("c"), split[1], b);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
+ &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDefEqual(frozen_graph_def, graph_def);
+}
+
+TEST_F(FreezeTest, GraphDefWithControlDependency) {
+ // Inputs that are control dependencies are prefixed with '^',
+ // e.g. ^control_dependency.
+ // Test that we traverse those correctly.
+ SavedModelBundle saved_model_bundle;
+ GraphDef graph_def;
+ Scope scope = Scope::NewRootScope();
+ Output source = ops::Const(scope.WithOpName("source"), 10.0f, {});
+ Output a = ops::Const(scope.WithOpName("a").WithControlDependencies(source),
+ {10.0f, 10.0f}, {2});
+ Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+ Output c = ops::Mul(scope.WithOpName("c"), a, b);
+ TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
+ &saved_model_bundle));
+
+ GraphDef frozen_graph_def;
+ std::unordered_set<string> inputs;
+ std::unordered_set<string> outputs;
+ TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+ &outputs));
+
+ GraphDefEqual(frozen_graph_def, graph_def);
+}
+
TEST_F(FreezeTest, GraphDefWithoutDependentVariables) {
TestFreezeGraphWithoutDependentVariables(false);
}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 309a991fc1..868d752927 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -40,7 +40,7 @@ namespace tfcompile {
namespace {
using ::testing::HasSubstr;
-using ::testing::UnorderedElementsAre;
+using ::testing::IsSupersetOf;
TEST(TFCompileTest, Add) {
AddComp add;
@@ -559,17 +559,10 @@ TEST(TFCompileTest, HloProfiling) {
auto tuple_profile_line = HasSubstr(
"%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
"%dot.0.2, f32[2,2]{1,0} %add.0.5)");
- auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
- auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
- hlo_profile_lines.erase(hlo_profile_lines.begin() + 7,
- hlo_profile_lines.end());
-
- EXPECT_THAT(
- hlo_profile_lines,
- UnorderedElementsAre(header, total_cycles_profile_line, dot_profile_line,
- add_profile_line, tuple_profile_line,
- arg0_profile_line, arg1_profile_line));
+ EXPECT_THAT(hlo_profile_lines,
+ IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
+ add_profile_line, tuple_profile_line}));
}
} // namespace
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index a6b3ce394c..df634ca3cc 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -176,6 +176,7 @@ cc_library(
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
+ "//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:sendrecv_ops",
@@ -217,6 +218,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime",
diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
index bc68afb322..805bbc62c1 100644
--- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc
+++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
@@ -354,6 +354,16 @@ bool GraphCycles::IsReachableNonConst(int32 x, int32 y) {
return reachable;
}
+bool GraphCycles::CanContractEdge(int32 a, int32 b) {
+ CHECK(HasEdge(a, b)) << "No edge exists from " << a << " to " << b;
+ RemoveEdge(a, b);
+ bool reachable = IsReachableNonConst(a, b);
+ // Restore the graph to its original state.
+ InsertEdge(a, b);
+ // If b is still reachable, contracting the edge would create a cycle.
+ return !reachable;
+}
+
bool GraphCycles::ContractEdge(int32 a, int32 b) {
CHECK(HasEdge(a, b));
RemoveEdge(a, b);
@@ -388,4 +398,8 @@ std::unordered_set<int32> GraphCycles::Successors(int32 node) {
return rep_->nodes_[node]->out;
}
+std::unordered_set<int32> GraphCycles::Predecessors(int32 node) {
+ return rep_->nodes_[node]->in;
+}
+
} // namespace tensorflow
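
CanContractEdge above answers the question by temporarily removing the edge, testing whether b is still reachable from a, and re-inserting it; if b remains reachable along another path, merging the two endpoints would close a cycle. A toy adjacency-set version of the same idea (the TinyGraph type is hypothetical and does not use the rank-based representation of GraphCycles):

```cpp
#include <cassert>
#include <functional>
#include <set>
#include <vector>

struct TinyGraph {
  std::vector<std::set<int>> out;  // adjacency sets, indexed by node id
  explicit TinyGraph(int n) : out(n) {}
  void AddEdge(int a, int b) { out[a].insert(b); }
  void RemoveEdge(int a, int b) { out[a].erase(b); }

  bool Reachable(int a, int b) const {
    std::vector<bool> seen(out.size(), false);
    std::function<bool(int)> dfs = [&](int u) {
      if (u == b) return true;
      if (seen[u]) return false;
      seen[u] = true;
      for (int v : out[u]) {
        if (dfs(v)) return true;
      }
      return false;
    };
    return dfs(a);
  }

  // Contracting a->b is safe only if b is unreachable from a once the direct
  // edge is removed; otherwise merging the endpoints would create a cycle.
  bool CanContractEdge(int a, int b) {
    RemoveEdge(a, b);
    bool reachable = Reachable(a, b);
    AddEdge(a, b);  // restore the graph
    return !reachable;
  }
};

int main() {
  // Same shape as the CanContractEdge test added further down:
  // 1->2, 1->3, 2->3, 2->4, 3->4.
  TinyGraph g(5);
  g.AddEdge(1, 2); g.AddEdge(1, 3); g.AddEdge(2, 3);
  g.AddEdge(2, 4); g.AddEdge(3, 4);
  assert(!g.CanContractEdge(1, 3));  // 1 still reaches 3 via 2
  assert(g.CanContractEdge(1, 2));
  return 0;
}
```
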
diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/jit/graphcycles/graphcycles.h
index d11d6e27b1..44448fa3d7 100644
--- a/tensorflow/compiler/jit/graphcycles/graphcycles.h
+++ b/tensorflow/compiler/jit/graphcycles/graphcycles.h
@@ -85,6 +85,9 @@ class GraphCycles {
// and returns false.
bool ContractEdge(int32 a, int32 b);
+ // Returns true if the edge can be contracted without creating a cycle,
+ // false otherwise.
+ bool CanContractEdge(int32 a, int32 b);
+
// Return whether dest_node is reachable from source_node
// by following edges.
bool IsReachable(int32 source_node, int32 dest_node) const;
@@ -115,6 +118,7 @@ class GraphCycles {
bool CheckInvariants() const;
std::unordered_set<int32> Successors(int32 node);
+ std::unordered_set<int32> Predecessors(int32 node);
// ----------------------------------------------------
struct Rep;
diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc
index e47b782207..274f5938a1 100644
--- a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc
+++ b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc
@@ -494,6 +494,20 @@ TEST_F(GraphCyclesTest, ContractEdge) {
EXPECT_TRUE(g_.HasEdge(1, 4));
}
+TEST_F(GraphCyclesTest, CanContractEdge) {
+ ASSERT_TRUE(AddEdge(1, 2));
+ ASSERT_TRUE(AddEdge(1, 3));
+ ASSERT_TRUE(AddEdge(2, 3));
+ ASSERT_TRUE(AddEdge(2, 4));
+ ASSERT_TRUE(AddEdge(3, 4));
+
+ EXPECT_FALSE(g_.CanContractEdge(1, 3));
+ EXPECT_FALSE(g_.CanContractEdge(2, 4));
+ EXPECT_TRUE(g_.CanContractEdge(1, 2));
+ EXPECT_TRUE(g_.CanContractEdge(2, 3));
+ EXPECT_TRUE(g_.CanContractEdge(3, 4));
+}
+
static void BM_StressTest(int iters, int num_nodes) {
while (iters > 0) {
tensorflow::GraphCycles g;
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 498d25cf56..65c0e8577f 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
+#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
@@ -63,6 +64,9 @@ class XlaDeviceDummyOp : public OpKernel {
ConstantOp); \
REGISTER_KERNEL_BUILDER( \
Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \
+ IdentityNOp); \
REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \
REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \
PlaceholderOp); \
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 0223f97a03..6a0f557627 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -60,19 +60,22 @@ XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
XlaAllocator::~XlaAllocator() {}
-xla::StatusOr<se::DeviceMemoryBase> XlaAllocator::Allocate(
+xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
- void* data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size);
+ AllocationAttributes attrs;
+ attrs.no_retry_on_failure = !retry_on_failure;
+ void* data =
+ wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
if (data == nullptr) {
return errors::ResourceExhausted("Out of memory while trying to allocate ",
size, " bytes.");
- } else {
- return se::DeviceMemoryBase(data, size);
}
+ return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
+ device_ordinal, this);
}
-Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) {
- wrapped_->DeallocateRaw(mem->opaque());
+Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
+ wrapped_->DeallocateRaw(mem.opaque());
return Status::OK();
}
@@ -238,7 +241,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
ctx->expected_output_dtype(i), shape, buffer, allocator);
- output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num});
+ output.set_buffer(xla::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
++output_num;
@@ -288,7 +291,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
write.type, write.shape, buffer, allocator);
- output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num});
+ output.set_buffer(xla::OwningDeviceMemory(), {output_num});
*variable->tensor() = output_tensor;
}
++output_num;
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index a2431253f8..4390701ccb 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -22,6 +22,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@@ -50,9 +52,9 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
public:
XlaAllocator(const se::Platform* platform, Allocator* wrapped);
~XlaAllocator() override;
- xla::StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
- bool retry_on_failure) override;
- Status Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) override;
+ xla::StatusOr<xla::OwningDeviceMemory> Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure) override;
+ Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
// before GPU execution takes place. Tensorflow uses the ordering of the main
diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc
index 27813efc0b..a45932403e 100644
--- a/tensorflow/compiler/jit/xla_launch_util_test.cc
+++ b/tensorflow/compiler/jit/xla_launch_util_test.cc
@@ -36,9 +36,9 @@ void BM_ExtractSubBuffer(int iters, int depth, int fan_out) {
for (int i = 0; i < iters; ++i) {
// Extract a buffer from approximately the middle of the first level of the
// tree.
- tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer,
- /*index=*/fan_out / 2,
- /*allocator=*/nullptr)
+ (void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer,
+ /*index=*/fan_out / 2,
+ /*allocator=*/nullptr)
.release();
}
}
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index ce6456880b..a7211c9c7e 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -52,20 +52,22 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
client->backend().transfer_manager()->HostShapeToDeviceShape(
on_host_shape);
- xla::ShapedBuffer buffer(on_host_shape, on_device_shape, client->platform(),
- device_ordinal);
- for (auto& index_to_buffer : buffer.buffers()) {
+ xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape,
+ client->backend().memory_allocator(),
+ device_ordinal);
+ for (auto& index_to_buffer : shaped_buffer.buffers()) {
xla::Shape subshape =
xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
uint64 size =
client->backend().transfer_manager()->GetByteSizeRequirement(subshape);
- TF_ASSIGN_OR_RETURN(index_to_buffer.second,
+ TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
client->backend().memory_allocator()->Allocate(
device_ordinal, size, /*retry_on_failure=*/false));
+ // Move our buffer into shaped_buffer, which takes ownership of it.
+ index_to_buffer.second = buffer.Forget();
}
- set_shaped_buffer(xla::ScopedShapedBuffer(
- std::move(buffer), client->backend().memory_allocator()));
+ set_shaped_buffer(std::move(shaped_buffer));
return Status::OK();
}
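
The rewrite above allocates each sub-buffer as an owning handle and calls Forget() to hand the raw memory over to the ScopedShapedBuffer, which then owns it. A stripped-down sketch of that hand-off (the OwningBuffer class below is hypothetical and backed by malloc/free, not xla::OwningDeviceMemory):

```cpp
#include <cassert>
#include <cstdlib>

class OwningBuffer {
 public:
  OwningBuffer() = default;
  explicit OwningBuffer(void* ptr) : ptr_(ptr) {}
  OwningBuffer(OwningBuffer&& other) : ptr_(other.ptr_) { other.ptr_ = nullptr; }
  ~OwningBuffer() { std::free(ptr_); }  // frees only if still owned

  // Releases ownership; the caller becomes responsible for the memory.
  void* Forget() {
    void* p = ptr_;
    ptr_ = nullptr;
    return p;
  }

 private:
  void* ptr_ = nullptr;
};

int main() {
  OwningBuffer buffer(std::malloc(64));
  void* raw = buffer.Forget();  // e.g. stored into another owning container
  assert(raw != nullptr);
  std::free(raw);  // whoever received the raw pointer must release it
  return 0;
}
```
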
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 9791792f29..96dfc8d8f1 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -410,6 +410,21 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "listdiff_op_test",
+ size = "small",
+ srcs = ["listdiff_op_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform_test",
+ "@six_archive//:six",
+ ],
+)
+
+tf_xla_py_test(
name = "lrn_ops_test",
size = "medium",
srcs = ["lrn_ops_test.py"],
diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py
new file mode 100644
index 0000000000..45a04f0cf5
--- /dev/null
+++ b/tensorflow/compiler/tests/listdiff_op_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for XLA listdiff operator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ListDiffTest(xla_test.XLATestCase):
+
+ def _testListDiff(self, x, y, out, idx):
+ for dtype in [dtypes.int32, dtypes.int64]:
+ for index_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session() as sess:
+ x_tensor = ops.convert_to_tensor(x, dtype=dtype)
+ y_tensor = ops.convert_to_tensor(y, dtype=dtype)
+ with self.test_scope():
+ out_tensor, idx_tensor = array_ops.listdiff(
+ x_tensor, y_tensor, out_idx=index_dtype)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ self.assertAllEqual(out, tf_out)
+ self.assertAllEqual(idx, tf_idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
+
+ def testBasic1(self):
+ self._testListDiff(x=[1, 2, 3, 4], y=[1, 2], out=[3, 4], idx=[2, 3])
+
+ def testBasic2(self):
+ self._testListDiff(x=[1, 2, 3, 4], y=[2], out=[1, 3, 4], idx=[0, 2, 3])
+
+ def testBasic3(self):
+ self._testListDiff(x=[1, 4, 3, 2], y=[4, 2], out=[1, 3], idx=[0, 2])
+
+ def testDuplicates(self):
+ self._testListDiff(x=[1, 2, 4, 3, 2, 3, 3, 1],
+ y=[4, 2],
+ out=[1, 3, 3, 3, 1],
+ idx=[0, 3, 5, 6, 7])
+
+ def testRandom(self):
+ num_random_tests = 10
+ int_low = -7
+ int_high = 8
+ max_size = 50
+ for _ in xrange(num_random_tests):
+ x_size = np.random.randint(max_size + 1)
+ x = np.random.randint(int_low, int_high, size=x_size)
+ y_size = np.random.randint(max_size + 1)
+ y = np.random.randint(int_low, int_high, size=y_size)
+ out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y]
+ if out_idx:
+ out, idx = map(list, zip(*out_idx))
+ else:
+ out = []
+ idx = []
+ self._testListDiff(list(x), list(y), out, idx)
+
+ def testFullyOverlapping(self):
+ self._testListDiff(x=[1, 2, 3, 4], y=[1, 2, 3, 4], out=[], idx=[])
+
+ def testNonOverlapping(self):
+ self._testListDiff(x=[1, 2, 3, 4],
+ y=[5, 6],
+ out=[1, 2, 3, 4],
+ idx=[0, 1, 2, 3])
+
+ def testEmptyX(self):
+ self._testListDiff(x=[], y=[1, 2], out=[], idx=[])
+
+ def testEmptyY(self):
+ self._testListDiff(x=[1, 2, 3, 4], y=[], out=[1, 2, 3, 4], idx=[0, 1, 2, 3])
+
+ def testEmptyXY(self):
+ self._testListDiff(x=[], y=[], out=[], idx=[])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index ba79f393a8..57a1d9b9e4 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -209,7 +209,9 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.expm1,
np.array([[-1, 1]], dtype=dtype),
- expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype))
+ expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype),
+ rtol=1e-5,
+ atol=1e-6)
self._assertOpOutputMatchesExpected(
math_ops.floor,
@@ -251,12 +253,12 @@ class UnaryOpsTest(XLATestCase):
np.array([[1, 2]], dtype=dtype),
expected=np.array([[0.540297, -0.41614]], dtype=dtype))
- # TODO(b/34703906): improve log1p implementation and make tolerance
- # tighter.
self._assertOpOutputMatchesExpected(
math_ops.log1p,
np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
- expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)))
+ expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)),
+ rtol=1e-4,
+ atol=1e-6)
self._assertOpOutputMatchesExpected(
math_ops.rint,
@@ -419,7 +421,9 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.expm1,
np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype),
- expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)))
+ expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)),
+ rtol=1e-6,
+ atol=1e-6)
self._assertOpOutputMatchesExpected(
math_ops.reciprocal,
@@ -441,13 +445,13 @@ class UnaryOpsTest(XLATestCase):
np.array([[5j, 3 - 2j]], dtype=dtype),
expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype)))
- # TODO(b/34703906): improve log1p implementation and make tolerance
- # tighter.
self._assertOpOutputMatchesExpected(
math_ops.log1p,
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype),
expected=np.log1p(
- np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)))
+ np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)),
+ rtol=1e-4,
+ atol=1e-6)
val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)
self._assertOpOutputMatchesExpected(
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 85ab4c41bf..e6da157c11 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -45,6 +45,7 @@ tf_kernel_library(
"image_resize_ops.cc",
"index_ops.cc",
"l2loss_op.cc",
+ "listdiff_op.cc",
"lrn_ops.cc",
"matmul_op.cc",
"matrix_band_part_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
new file mode 100644
index 0000000000..0388b4c830
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific ListDiff Op. This only supports constant DT_INT32 and DT_INT64
+// input.
+
+#include <unordered_set>
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr std::array<DataType, 2> kListDiffTypes = {DT_INT32, DT_INT64};
+
+// ListDiffOp is an XLA kernel that supports constant-only x and y input.
+class ListDiffOp : public XlaOpKernel {
+ public:
+ explicit ListDiffOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(0)),
+ errors::InvalidArgument("ListDiff expects x as a vector, not ",
+ context->InputShape(0).DebugString()));
+
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(1)),
+ errors::InvalidArgument("ListDiff expects y as a vector, not ",
+ context->InputShape(1).DebugString()));
+
+ DataType val_type = context->expected_output_dtype(0);
+ DataType idx_type = context->expected_output_dtype(1);
+
+ Status status;
+ switch (val_type) {
+ case DT_INT32:
+ status = ListDiffWithIndexType<int32>(context, idx_type);
+ break;
+ case DT_INT64:
+ status = ListDiffWithIndexType<int64>(context, idx_type);
+ break;
+ default:
+ // This should never happen since we restrict this kernel to match only
+ // inputs with a supported Tensor datatype.
+ status = errors::InvalidArgument("ListDiff expects x and y as either ",
+ "int32 or int64, not ",
+ DataTypeString(val_type));
+ }
+ OP_REQUIRES_OK(context, status);
+ }
+
+ private:
+ template <typename Tval, typename Tidx>
+ Status ListDiff(XlaOpKernelContext* context) {
+ std::vector<int64> x_input, y_input;
+ TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(0, &x_input));
+ TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(1, &y_input));
+
+ std::unordered_set<Tval> y_input_set;
+ y_input_set.reserve(y_input.size());
+ for (auto y : y_input) {
+ y_input_set.insert(y);
+ }
+
+ std::vector<Tval> val_output;
+ std::vector<Tidx> idx_output;
+ auto x_size = x_input.size();
+ for (Tidx i = 0; i < x_size; ++i) {
+ if (y_input_set.count(x_input[i]) > 0) {
+ continue;
+ }
+ val_output.push_back(x_input[i]);
+ idx_output.push_back(i);
+ }
+
+ context->SetOutput(0, context->builder()->ConstantR1<Tval>(val_output));
+ context->SetOutput(1, context->builder()->ConstantR1<Tidx>(idx_output));
+ return Status::OK();
+ }
+
+ template <typename Tval>
+ Status ListDiffWithIndexType(XlaOpKernelContext* context, DataType idx_type) {
+ switch (idx_type) {
+ case DT_INT32:
+ return ListDiff<Tval, int32>(context);
+ case DT_INT64:
+ return ListDiff<Tval, int64>(context);
+ default:
+ return errors::InvalidArgument(
+ "ListDiff expects idx_out as either int32 or int64, not ",
+ DataTypeString(idx_type));
+ }
+ }
+};
+
+REGISTER_XLA_OP(Name("ListDiff")
+ .TypeConstraint("T", kListDiffTypes)
+ .CompileTimeConstInput("x")
+ .CompileTimeConstInput("y"),
+ ListDiffOp);
+
+} // namespace
+} // namespace tensorflow
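
Because x and y are required to be compile-time constants, the kernel evaluates the difference entirely on the host and emits the two results as constant R1 literals, preserving the order and indices of the surviving x elements. A host-only sketch of that computation (plain std::vector and int64_t, no XLA builder involved):

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_set>
#include <utility>
#include <vector>

// Keep every element of x that does not occur in y, together with its index.
std::pair<std::vector<int64_t>, std::vector<int64_t>> ListDiff(
    const std::vector<int64_t>& x, const std::vector<int64_t>& y) {
  std::unordered_set<int64_t> y_set(y.begin(), y.end());
  std::vector<int64_t> out, idx;
  for (int64_t i = 0; i < static_cast<int64_t>(x.size()); ++i) {
    if (!y_set.count(x[i])) {
      out.push_back(x[i]);
      idx.push_back(i);
    }
  }
  return {out, idx};
}

int main() {
  // Mirrors testDuplicates in listdiff_op_test.py above.
  auto result = ListDiff({1, 2, 4, 3, 2, 3, 3, 1}, {4, 2});
  assert((result.first == std::vector<int64_t>{1, 3, 3, 3, 1}));
  assert((result.second == std::vector<int64_t>{0, 3, 5, 6, 7}));
  return 0;
}
```
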
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index a4f50f52eb..3f6e218bcc 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -100,8 +100,7 @@ XLAJIT_MAKE_UNARY(Cosh,
XLAJIT_MAKE_UNARY(Sin, b->Sin(x));
XLAJIT_MAKE_UNARY(Exp, b->Exp(x));
-// TODO(b/34703906): use a more accurate implementation of expm1.
-XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0))));
+XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x));
XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x));
@@ -115,8 +114,7 @@ XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x));
XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x));
XLAJIT_MAKE_UNARY(Log, b->Log(x));
-// TODO(b/34703906): use a more accurate implementation of log1p.
-XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x)));
+XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x));
XLAJIT_MAKE_UNARY(Invert, b->Not(x));
XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x));
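
Switching to the builder's dedicated Expm1 and Log1p calls addresses the accuracy issue tracked in TODO(b/34703906): for |x| far below 1, exp(x) - 1 and log(1 + x) lose most of their significant digits to cancellation, which is also why the test tolerances above could be tightened. A small numerical illustration with the C standard-library equivalents (this demonstrates only the precision argument, not how XLA lowers these ops):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  double x = 1e-15;
  // The naive forms round 1 + x (or exp(x)) before the subtraction, so the
  // leading digits of the result are already wrong at this magnitude.
  std::printf("exp(x) - 1 = %.17g\n", std::exp(x) - 1.0);
  std::printf("expm1(x)   = %.17g\n", std::expm1(x));
  std::printf("log(1 + x) = %.17g\n", std::log(1.0 + x));
  std::printf("log1p(x)   = %.17g\n", std::log1p(x));
  return 0;
}
```
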
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 04ad3694a0..ef12b1618b 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -141,7 +141,6 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/tests:client_library_test_base",
diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc
index 2c3cd658e0..43e1c1e9fe 100644
--- a/tensorflow/compiler/tf2xla/literal_util.cc
+++ b/tensorflow/compiler/tf2xla/literal_util.cc
@@ -40,7 +40,7 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
return Status::OK();
}
-Status CopyLiteralToHostTensor(const xla::Literal& literal,
+Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal,
Tensor* host_tensor) {
TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) &&
xla::ShapeUtil::ElementsIn(literal.shape()) ==
@@ -63,8 +63,8 @@ Status CopyLiteralToHostTensor(const xla::Literal& literal,
return Status::OK();
}
-Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
- Tensor* host_tensor) {
+Status LiteralToHostTensor(const xla::LiteralSlice& literal,
+ DataType target_type, Tensor* host_tensor) {
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape));
*host_tensor = Tensor(target_type, shape);
diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h
index f283b02368..220bec1553 100644
--- a/tensorflow/compiler/tf2xla/literal_util.h
+++ b/tensorflow/compiler/tf2xla/literal_util.h
@@ -36,13 +36,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
// derivable from the type of <literal>, because multiple tensorflow types map
// to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in
// XLA).
-Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
- Tensor* host_tensor);
+Status LiteralToHostTensor(const xla::LiteralSlice& literal,
+ DataType target_type, Tensor* host_tensor);
// Copies the contents of 'literal' to a previously allocated tensor
// 'host_tensor'. The tensor and the literal must have the same number of
// elements and the same type.
-Status CopyLiteralToHostTensor(const xla::Literal& literal,
+Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal,
Tensor* host_tensor);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 6b8918b261..4382ffe6ba 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -225,7 +225,7 @@ TEST_F(XlaCompilerTest, Simple) {
xla::Literal::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
@@ -320,7 +320,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
xla::Literal::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(
+ xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
{
@@ -355,7 +356,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
xla::Literal::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected =
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal);
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
}
}
@@ -523,7 +524,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
{output_base.get(), output_grad1.get(), output_grad2.get()});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
// Tests compilation and execution of a graph that adds two tensors.
@@ -746,7 +747,7 @@ TEST_F(XlaCompilerTest, Variables) {
xla::Literal::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
// Tests a simple graph that reads and writes a variable, with a
@@ -811,7 +812,7 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
xla::Literal::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
} // namespace
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index dbf14f32bc..92936b17c8 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -99,9 +99,9 @@ cc_library(
hdrs = ["service_interface.h"],
visibility = [":friends"],
deps = [
+ ":status",
":xla_data_proto",
":xla_proto",
- "//tensorflow/core:lib",
],
)
@@ -245,6 +245,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":protobuf_util",
+ ":status",
":status_macros",
":statusor",
":types",
@@ -331,6 +332,23 @@ tf_cc_test(
)
cc_library(
+ name = "error_spec",
+ hdrs = ["error_spec.h"],
+)
+
+cc_library(
+ name = "literal_comparison",
+ srcs = ["literal_comparison.cc"],
+ hdrs = ["literal_comparison.h"],
+ deps = [
+ ":error_spec",
+ ":literal_util",
+ ":util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "metric_table_report",
srcs = ["metric_table_report.cc"],
hdrs = ["metric_table_report.h"],
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index aac3273d5f..989cd61d9f 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -179,31 +179,6 @@ cc_library(
)
cc_library(
- name = "computation_builder",
- srcs = ["computation_builder.cc"],
- hdrs = ["computation_builder.h"],
- deps = [
- ":client",
- ":computation",
- ":global_data",
- ":padding",
- "//tensorflow/compiler/xla:array",
- "//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:array3d",
- "//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla:xla_proto",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "sharding_builder",
srcs = ["sharding_builder.cc"],
hdrs = ["sharding_builder.h"],
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 328e1b8fa8..0a79b3cf27 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -336,7 +336,7 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
ExecuteParallelResponse response;
VLOG(1) << "making execute-parallel request: " << request.ShortDebugString();
- tensorflow::Status s = stub_->ExecuteParallel(&request, &response);
+ Status s = stub_->ExecuteParallel(&request, &response);
VLOG(1) << "done with request";
if (!s.ok()) {
@@ -372,7 +372,7 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
ExecuteParallelResponse response;
VLOG(1) << "making execute-graph-parallel request: "
<< request.ShortDebugString();
- tensorflow::Status s = stub_->ExecuteGraphParallel(&request, &response);
+ Status s = stub_->ExecuteGraphParallel(&request, &response);
VLOG(1) << "done with request";
if (!s.ok()) {
@@ -401,7 +401,7 @@ StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
GetDeviceHandlesResponse response;
VLOG(1) << "making get device request: " << request.ShortDebugString();
- tensorflow::Status s = stub_->GetDeviceHandles(&request, &response);
+ Status s = stub_->GetDeviceHandles(&request, &response);
VLOG(1) << "done with request";
if (!s.ok()) {
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
deleted file mode 100644
index 83c7cb1744..0000000000
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ /dev/null
@@ -1,1574 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/client/computation_builder.h"
-
-#include <stddef.h>
-#include <array>
-#include <numeric>
-#include <set>
-#include <vector>
-
-#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/compiler/xla/xla.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace xla {
-
-ComputationBuilder::ComputationBuilder(Client* client,
- const string& computation_name)
- : name_(computation_name), client_(client) {}
-
-ComputationBuilder::~ComputationBuilder() {}
-
-void ComputationBuilder::NoteError(const Status& error) {
- if (die_immediately_on_error_) {
- LOG(FATAL) << "error building computation: " << error;
- }
-
- if (first_error_.ok()) {
- first_error_ = error;
- first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
- }
-}
-
-std::unique_ptr<ComputationBuilder> ComputationBuilder::CreateSubBuilder(
- const string& computation_name) {
- auto sub_builder = MakeUnique<ComputationBuilder>(client_, computation_name);
- sub_builder->parent_builder_ = this;
- sub_builder->die_immediately_on_error_ = die_immediately_on_error_;
- return sub_builder;
-}
-
-Status ComputationBuilder::PrepareComputation() {
- TF_RETURN_IF_ERROR(first_error_);
-
- if (!computation_.IsNull()) {
- return Status::OK();
- }
-
- ComputationRequest request;
- request.set_name(name_);
- ComputationResponse response;
-
- VLOG(2) << "making computation request";
- Status s = client_->stub()->Computation(&request, &response);
- VLOG(2) << "done with computation request";
-
- if (!s.ok()) {
- NoteError(s);
- return first_error_;
- }
-
- computation_ = Computation(client_->stub(), response.computation());
- return Status::OK();
-}
-
-Status ComputationBuilder::RunOp(OpRequest* op_request,
- OpResponse* op_response) {
- TF_RETURN_IF_ERROR(first_error_);
- TF_RETURN_IF_ERROR(PrepareComputation());
-
- // Fill in fields that are set on every OpRequest.
- *op_request->mutable_computation() = computation_.handle();
- *op_request->mutable_metadata() = metadata_;
- if (sharding_) {
- *op_request->mutable_sharding() = *sharding_;
- }
-
- const string& op_name =
- OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name();
- VLOG(2) << "running op request: " << op_name;
- Status status = client_->stub()->Op(op_request, op_response);
- VLOG(2) << "done with op request: " << op_name;
- return status;
-}
-
-void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) {
- OpResponse op_response;
- Status status = RunOp(op_request, &op_response);
- if (!status.ok()) {
- NoteError(status);
- }
-}
-
-ComputationDataHandle ComputationBuilder::RunOpAndParseResponse(
- OpRequest* op_request) {
- OpResponse op_response;
- Status status = RunOp(op_request, &op_response);
- if (!status.ok()) {
- NoteError(status);
- return ComputationDataHandle();
- }
- if (op_response.output().handle() == 0) {
- NoteError(InternalError("No output handle"));
- return ComputationDataHandle();
- }
- return op_response.output();
-}
-
-bool ComputationBuilder::MakeWindow(
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation, Window* window) {
- const auto verify_size = [&](const size_t x, const char* x_name) {
- if (x == 0 || x == window_dimensions.size()) {
- return true;
- } else {
- NoteError(InvalidArgument(
- "%s", tensorflow::strings::StrCat(
- "Window has different number of window dimensions than of ",
- x_name, "\nNumber of window dimensions: ",
- window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
- "\n")
- .c_str())); //
- return false;
- }
- };
- if (!verify_size(window_strides.size(), "window strides") ||
- !verify_size(padding.size(), "padding entries") ||
- !verify_size(lhs_dilation.size(), "lhs dilation factors") ||
- !verify_size(rhs_dilation.size(), "rhs dilation factors")) {
- return false;
- }
-
- window->Clear();
- for (size_t i = 0; i < window_dimensions.size(); i++) {
- auto dim = window->add_dimensions();
- dim->set_size(window_dimensions[i]);
- if (!window_strides.empty()) {
- dim->set_stride(window_strides[i]);
- } else {
- dim->set_stride(1);
- }
- if (!padding.empty()) {
- dim->set_padding_low(padding[i].first);
- dim->set_padding_high(padding[i].second);
- } else {
- dim->set_padding_low(0);
- dim->set_padding_high(0);
- }
- if (!lhs_dilation.empty()) {
- dim->set_base_dilation(lhs_dilation[i]);
- } else {
- dim->set_base_dilation(1);
- }
- if (!rhs_dilation.empty()) {
- dim->set_window_dilation(rhs_dilation[i]);
- } else {
- dim->set_window_dilation(1);
- }
- dim->set_window_reversal(false);
- }
- return true;
-}
-
-ComputationDataHandle ComputationBuilder::ConstantLiteral(
- const Literal& literal) {
- OpRequest op_request;
- ConstantRequest* request = op_request.mutable_constant_request();
- *request->mutable_literal() = literal.ToProto();
- VLOG(3) << "created constant: " << request->literal().ShortDebugString();
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number,
- const Shape& shape,
- const string& name) {
- OpRequest op_request;
- ParameterRequest* request = op_request.mutable_parameter_request();
- *request->mutable_shape() = shape;
- request->set_parameter(parameter_number);
- request->set_name(name);
- return RunOpAndParseResponse(&op_request);
-}
-
-StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShapeWithoutNoteError(
- const ComputationDataHandle& operand) {
- GetLocalShapeRequest request;
- *request.mutable_computation() = computation_.handle();
- *request.mutable_operand() = operand;
- GetLocalShapeResponse response;
-
- VLOG(2) << "making get-shape request";
- TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response));
- VLOG(2) << "done with request";
-
- TF_RET_CHECK(response.has_shape());
- std::unique_ptr<Shape> shape = WrapUnique(response.release_shape());
- TF_RET_CHECK(shape != nullptr);
- return std::move(shape);
-}
-
-StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShape(
- const ComputationDataHandle& operand) {
- TF_RETURN_IF_ERROR(first_error_);
-
- auto status_or_shape = GetShapeWithoutNoteError(operand);
- if (!status_or_shape.ok()) {
- NoteError(status_or_shape.status());
- return first_error_;
- }
- return status_or_shape;
-}
-
-StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() {
- TF_RETURN_IF_ERROR(first_error_);
-
- GetComputationShapeRequest request;
- *request.mutable_computation() = computation_.handle();
- GetComputationShapeResponse response;
-
- VLOG(2) << "making get-program-shape-request";
- Status status = client_->stub()->GetComputationShape(&request, &response);
- VLOG(2) << "done with get-program-shape-request";
-
- if (!status.ok()) {
- first_error_ = status;
- return status;
- }
-
- TF_RET_CHECK(response.has_program_shape());
- return std::move(*response.mutable_program_shape());
-}
-
-ComputationDataHandle ComputationBuilder::Slice(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
- OpRequest op_request;
- SliceRequest* request = op_request.mutable_slice_request();
- *request->mutable_operand() = operand;
- for (int64 index : start_indices) {
- request->add_start_indices(index);
- }
- for (int64 index : limit_indices) {
- request->add_limit_indices(index);
- }
- for (int64 index : strides) {
- request->add_strides(index);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::SliceInDim(
- const ComputationDataHandle& operand, int64 start_index, int64 limit_index,
- int64 stride, int64 dimno) {
- StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
- if (!shape_status.ok()) {
- NoteError(shape_status.status());
- return ComputationDataHandle{};
- }
- const Shape& shape = *shape_status.ValueOrDie();
- std::vector<int64> starts(ShapeUtil::Rank(shape), 0);
- std::vector<int64> limits(shape.dimensions().begin(),
- shape.dimensions().end());
- std::vector<int64> strides(ShapeUtil::Rank(shape), 1);
- starts[dimno] = start_index;
- limits[dimno] = limit_index;
- strides[dimno] = stride;
- return Slice(operand, starts, limits, strides);
-}
-
-ComputationDataHandle ComputationBuilder::DynamicSlice(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- OpRequest op_request;
- DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request();
- *request->mutable_operand() = operand;
- *request->mutable_start_indices() = start_indices;
- for (int64 index : slice_sizes) {
- request->add_slice_sizes(index);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::DynamicUpdateSlice(
- const ComputationDataHandle& operand, const ComputationDataHandle& update,
- const ComputationDataHandle& start_indices) {
- OpRequest op_request;
- DynamicUpdateSliceRequest* request =
- op_request.mutable_dynamic_update_slice_request();
- *request->mutable_operand() = operand;
- *request->mutable_update() = update;
- *request->mutable_start_indices() = start_indices;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::ConcatInDim(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- int64 dimension) {
- OpRequest op_request;
- ConcatenateRequest* request = op_request.mutable_concatenate_request();
- for (const ComputationDataHandle& operand : operands) {
- *request->add_operands() = operand;
- }
- request->set_dimension(dimension);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Broadcast(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
- OpRequest op_request;
- BroadcastRequest* request = op_request.mutable_broadcast_request();
- *request->mutable_operand() = operand;
- for (int64 size : broadcast_sizes) {
- request->add_broadcast_sizes(size);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Pad(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& padding_value,
- const PaddingConfig& padding_config) {
- OpRequest op_request;
- PadRequest* request = op_request.mutable_pad_request();
- *request->mutable_operand() = operand;
- *request->mutable_padding_value() = padding_value;
- *request->mutable_padding_config() = padding_config;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Reshape(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
- OpRequest op_request;
- ReshapeRequest* request = op_request.mutable_reshape_request();
- *request->mutable_operand() = operand;
- for (int64 dimension : dimensions) {
- request->add_dimensions(dimension);
- }
- for (int64 new_size : new_sizes) {
- request->add_new_sizes(new_size);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Reshape(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes) {
- if (!first_error_.ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
- if (!shape.ok()) {
- return ComputationDataHandle();
- }
- std::vector<int64> dimensions(shape.ValueOrDie()->dimensions().size());
- std::iota(dimensions.begin(), dimensions.end(), 0);
- return Reshape(operand, dimensions, new_sizes);
-}
-
-ComputationDataHandle ComputationBuilder::Collapse(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
- if (!first_error_.ok()) {
- return ComputationDataHandle();
- }
-
- // Don't support out-of-order collapse here.
- // Checks that the collapsed dimensions are in order and consecutive.
- for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
- i < dimensions.size(); ++i) {
- if (dimensions[i] - 1 != dimensions[i - 1]) {
- NoteError(InvalidArgument(
- "Collapsed dimensions are not in order and consecutive."));
- return ComputationDataHandle();
- }
- }
-
- // Create a new sizes vector from the old shape, replacing the collapsed
- // dimensions by the product of their sizes.
- StatusOr<std::unique_ptr<Shape>> shape_or_status = GetShape(operand);
- if (!shape_or_status.ok()) {
- return ComputationDataHandle();
- }
- std::unique_ptr<Shape> original_shape = shape_or_status.ConsumeValueOrDie();
-
- VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
- VLOG(3) << "dims to collapse: "
- << tensorflow::str_util::Join(dimensions, ",");
-
- if (dimensions.size() <= 1) {
- // Not collapsing anything; we can trivially return the operand rather
- // than enqueueing a no-op reshape.
- return operand;
- }
-
- std::vector<int64> new_sizes;
- for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) {
- if (i <= dimensions.front() || i > dimensions.back()) {
- new_sizes.push_back(original_shape->dimensions(i));
- } else {
- new_sizes.back() *= original_shape->dimensions(i);
- }
- }
-
- VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
- << "]";
-
- return Reshape(operand, new_sizes);
-}
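A minimal worked trace of the new_sizes computation above, assuming an operand of shape f32[256, 2, 2, 32] and dimensions = {0, 1, 2} (the shapes here are illustrative, not taken from the original source):

// front = 0, back = 2
//   i = 0: i <= front         -> new_sizes = {256}
//   i = 1: front < i <= back  -> new_sizes = {512}
//   i = 2: front < i <= back  -> new_sizes = {1024}
//   i = 3: i > back           -> new_sizes = {1024, 32}
// The method then enqueues Reshape(operand, {1024, 32}).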
-
-void ComputationBuilder::Trace(const string& tag,
- const ComputationDataHandle& operand) {
- OpRequest op_request;
- TraceRequest* request = op_request.mutable_trace_request();
- request->set_tag(tag);
- *request->mutable_operand() = operand;
- RunOpAndNoteError(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Select(
- const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
- const ComputationDataHandle& on_false) {
- return TernaryOp(TRIOP_SELECT, pred, on_true, on_false);
-}
-
-ComputationDataHandle ComputationBuilder::Tuple(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
- OpRequest op_request;
- VariadicOpRequest* request = op_request.mutable_variadic_op_request();
- request->set_varop(VAROP_TUPLE);
- for (const ComputationDataHandle& operand : elements) {
- *request->add_operands() = operand;
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::GetTupleElement(
- const ComputationDataHandle& tuple_data, int64 index) {
- OpRequest op_request;
- GetTupleElementRequest* request =
- op_request.mutable_get_tuple_element_request();
- *request->mutable_operand() = tuple_data;
- request->set_index(index);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Eq(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Ne(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Ge(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Gt(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Le(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Lt(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Dot(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
- StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
- if (!lhs_shape_or_status.ok()) {
- return ComputationDataHandle();
- }
- std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
-
- DotDimensionNumbers dimension_numbers;
- dimension_numbers.add_lhs_contracting_dimensions(
- lhs_shape->dimensions_size() == 1 ? 0 : 1);
- dimension_numbers.add_rhs_contracting_dimensions(0);
- return DotGeneral(lhs, rhs, dimension_numbers);
-}
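A small sketch of the default contraction this sets up, with assumed shapes:

// With the dimension numbers chosen above:
//   lhs f32[m, k] contracts dimension 1, rhs f32[k, n] contracts dimension 0,
//   so Dot(lhs, rhs) yields f32[m, n].
// For a rank-1 lhs f32[k], the lhs contracting dimension falls back to 0,
// giving a vector-matrix product of shape f32[n].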
-
-ComputationDataHandle ComputationBuilder::DotGeneral(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- const DotDimensionNumbers& dimension_numbers) {
- OpRequest op_request;
- DotRequest* request = op_request.mutable_dot_request();
- *request->mutable_lhs() = lhs;
- *request->mutable_rhs() = rhs;
- *request->mutable_dimension_numbers() = dimension_numbers;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Conv(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
- return ConvWithGeneralDimensions(
- lhs, rhs, window_strides, padding,
- CreateDefaultConvDimensionNumbers(window_strides.size()));
-}
-
-ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
- return ConvGeneral(lhs, rhs, window_strides, padding,
- CreateDefaultConvDimensionNumbers(window_strides.size()));
-}
-
-bool ComputationBuilder::VerifyConvolution(
- const Shape& lhs_shape, const Shape& rhs_shape,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) {
- NoteError(
- InvalidArgument("Convolution arguments must have same number of "
- "dimensions. Got: %s and %s",
- ShapeUtil::HumanString(lhs_shape).c_str(),
- ShapeUtil::HumanString(rhs_shape).c_str()));
- return false;
- }
- int num_dims = ShapeUtil::Rank(lhs_shape);
- if (num_dims < 2) {
- NoteError(InvalidArgument(
- "Convolution expects argument arrays with >= 3 dimensions. "
- "Got: %s and %s",
- ShapeUtil::HumanString(lhs_shape).c_str(),
- ShapeUtil::HumanString(rhs_shape).c_str()));
- return false;
- }
- int num_spatial_dims = num_dims - 2;
-
- const auto check_spatial_dimensions =
- [&](const char* const field_name,
- const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
- numbers) {
- if (numbers.size() != num_spatial_dims) {
- NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
- num_spatial_dims, field_name,
- numbers.size()));
- return false;
- }
- for (int i = 0; i < numbers.size(); ++i) {
- if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
- NoteError(
- InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
- field_name, i, numbers.Get(i)));
- return false;
- }
- }
- return true;
- };
- return check_spatial_dimensions(
- "input_spatial_dimensions",
- dimension_numbers.input_spatial_dimensions()) &&
- check_spatial_dimensions(
- "kernel_spatial_dimensions",
- dimension_numbers.kernel_spatial_dimensions()) &&
- check_spatial_dimensions(
- "output_spatial_dimensions",
- dimension_numbers.output_spatial_dimensions());
-}
-
-ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- if (!first_error_.ok() || !PrepareComputation().ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
- if (!lhs_shape_or_status.ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
- if (!rhs_shape_or_status.ok()) {
- return ComputationDataHandle();
- }
-
- std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
- std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();
-
- if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
- NoteError(InternalError("failed to verify convolution"));
- return ComputationDataHandle();
- }
-
- std::vector<int64> base_area_dimensions(
- dimension_numbers.input_spatial_dimensions_size());
- for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
- ++i) {
- base_area_dimensions[i] =
- lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
- }
-
- std::vector<int64> window_dimensions(
- dimension_numbers.kernel_spatial_dimensions_size());
- for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
- window_dimensions[i] =
- rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
- }
-
- return ConvGeneral(lhs, rhs, window_strides,
- MakePadding(base_area_dimensions, window_dimensions,
- window_strides, padding),
- dimension_numbers);
-}
-
-ComputationDataHandle ComputationBuilder::ConvGeneral(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
- dimension_numbers);
-}
-
-ComputationDataHandle ComputationBuilder::ConvGeneralDilated(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- if (!first_error_.ok() || !PrepareComputation().ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
- if (!lhs_shape_or_status.ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
- if (!rhs_shape_or_status.ok()) {
- return ComputationDataHandle();
- }
-
- std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
- std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();
- if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
- // Error is recorded in VerifyConvolution.
- return ComputationDataHandle();
- }
-
- std::vector<int64> window_dimensions(
- dimension_numbers.kernel_spatial_dimensions_size());
- for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
- window_dimensions[i] =
- rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
- }
-
- OpRequest op_request;
- ConvolveRequest* request = op_request.mutable_convolve_request();
- *request->mutable_lhs() = lhs;
- *request->mutable_rhs() = rhs;
- *request->mutable_dimension_numbers() = dimension_numbers;
-
- if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation,
- rhs_dilation, request->mutable_window())) {
- // Error is recorded in MakeWindow.
- return ComputationDataHandle();
- }
-
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Fft(
- const ComputationDataHandle& operand, const FftType fft_type,
- const tensorflow::gtl::ArraySlice<int64> fft_length) {
- OpRequest op_request;
- FftRequest* request = op_request.mutable_fft_request();
- *request->mutable_operand() = operand;
- request->set_fft_type(fft_type);
- for (int64 dim_len : fft_length) {
- request->add_fft_length(dim_len);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape,
- const string& config) {
- OpRequest op_request;
- InfeedRequest* request = op_request.mutable_infeed_request();
- *request->mutable_shape() = shape;
- *request->mutable_config() = config;
- return RunOpAndParseResponse(&op_request);
-}
-
-void ComputationBuilder::Outfeed(const ComputationDataHandle& operand,
- const Shape& shape_with_layout,
- const string& outfeed_config) {
- OpRequest op_request;
- OutfeedRequest* request = op_request.mutable_outfeed_request();
- request->set_outfeed_config(outfeed_config);
- *request->mutable_operand() = operand;
- *request->mutable_shape() = shape_with_layout;
- RunOpAndNoteError(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Call(
- const Computation& computation,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
- OpRequest op_request;
- CallRequest* request = op_request.mutable_call_request();
- *request->mutable_to_apply() = computation.handle();
- for (const ComputationDataHandle& operand : operands) {
- *request->add_operands() = operand;
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::CustomCall(
- const string& call_target_name,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const Shape& shape) {
- OpRequest op_request;
- CustomCallRequest* request = op_request.mutable_custom_call_request();
- request->set_call_target_name(call_target_name);
- for (const ComputationDataHandle& operand : operands) {
- *request->add_operands() = operand;
- }
- *request->mutable_shape() = shape;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::HostCompute(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const string& channel_name, int64 cost_estimate_ns, const Shape& shape) {
- OpRequest op_request;
- HostComputeRequest* request = op_request.mutable_host_compute_request();
- for (const ComputationDataHandle& operand : operands) {
- *request->add_operands() = operand;
- }
- *request->mutable_shape() = shape;
- request->set_channel_name(channel_name);
- request->set_cost_estimate_ns(cost_estimate_ns);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Complex(
- const ComputationDataHandle& real, const ComputationDataHandle& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Conj(
- const ComputationDataHandle& operand) {
- return Complex(Real(operand), Neg(Imag(operand)));
-}
-
-ComputationDataHandle ComputationBuilder::Add(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Sub(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Mul(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Div(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Rem(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Max(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Min(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::And(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Or(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions);
-}
-
-// TODO(b/65209188): Create a dedicated lowering for Xor
-ComputationDataHandle ComputationBuilder::Xor(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return Or(And(Not(lhs), rhs, broadcast_dimensions),
- And(lhs, Not(rhs), broadcast_dimensions));
-}
-
-ComputationDataHandle ComputationBuilder::Not(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_NOT, operand);
-}
-
-ComputationDataHandle ComputationBuilder::ShiftLeft(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::ShiftRightArithmetic(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::ShiftRightLogical(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Abs(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_ABS, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Atan2(
- const ComputationDataHandle& y, const ComputationDataHandle& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::Exp(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_EXP, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Floor(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_FLOOR, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Ceil(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_CEIL, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Round(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_ROUND_NEAREST_AFZ, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Log(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_LOG, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Sign(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_SIGN, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Cos(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_COS, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Sin(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_SIN, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Tanh(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_TANH, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Real(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_REAL, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Imag(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_IMAG, operand);
-}
-
-ComputationDataHandle ComputationBuilder::IsFinite(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_IS_FINITE, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Transpose(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> permutation) {
- OpRequest op_request;
- TransposeRequest* request = op_request.mutable_transpose_request();
- *request->mutable_operand() = operand;
- for (int64 dimension : permutation) {
- request->add_dimensions(dimension);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Rev(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
- OpRequest op_request;
- ReverseRequest* request = op_request.mutable_reverse_request();
- *request->mutable_operand() = operand;
- for (int64 dimension : dimensions) {
- request->add_dimensions(dimension);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Sort(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_SORT, operand);
-}
-
-ComputationDataHandle ComputationBuilder::SqrtF32(
- const ComputationDataHandle& operand) {
- return BinaryOp(BINOP_POW, operand, ConstantR0<float>(0.5),
- /*broadcast_dimensions=*/{});
-}
-
-ComputationDataHandle ComputationBuilder::Pow(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions);
-}
-
-ComputationDataHandle ComputationBuilder::ConvertElementType(
- const ComputationDataHandle& operand, PrimitiveType new_element_type) {
- if (!first_error_.ok() || !PrepareComputation().ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
- if (!shape_status.ok()) {
- return ComputationDataHandle();
- }
- std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();
-
- OpRequest op_request;
- ConvertRequest* request = op_request.mutable_convert_request();
- *request->mutable_operand() = operand;
- request->set_new_element_type(new_element_type);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::BitcastConvertType(
- const ComputationDataHandle& operand, PrimitiveType new_element_type) {
- if (!first_error_.ok() || !PrepareComputation().ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
- if (!shape_status.ok()) {
- return ComputationDataHandle();
- }
- std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();
-
- OpRequest op_request;
- ConvertRequest* request = op_request.mutable_bitcast_convert_request();
- *request->mutable_operand() = operand;
- request->set_new_element_type(new_element_type);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::SquareF32(
- const ComputationDataHandle& operand) {
- return BinaryOp(BINOP_POW, operand, ConstantR0<float>(2.0),
- /*broadcast_dimensions=*/{});
-}
-
-ComputationDataHandle ComputationBuilder::ReciprocalF32(
- const ComputationDataHandle& operand) {
- return BinaryOp(BINOP_POW, operand, ConstantR0<float>(-1.0),
- /*broadcast_dimensions=*/{});
-}
-
-ComputationDataHandle ComputationBuilder::Neg(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_NEGATE, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Clz(
- const ComputationDataHandle& operand) {
- return UnaryOp(UNOP_CLZ, operand);
-}
-
-ComputationDataHandle ComputationBuilder::Clamp(
- const ComputationDataHandle& min, const ComputationDataHandle& operand,
- const ComputationDataHandle& max) {
- return TernaryOp(TRIOP_CLAMP, min, operand, max);
-}
-
-ComputationDataHandle ComputationBuilder::UnaryOp(
- UnaryOperation unop, const ComputationDataHandle& operand) {
- OpRequest op_request;
- UnaryOpRequest* request = op_request.mutable_unary_op_request();
- request->set_unop(unop);
- *request->mutable_operand() = operand;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::BinaryOp(
- BinaryOperation binop, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- OpRequest op_request;
- BinaryOpRequest* request = op_request.mutable_binary_op_request();
- request->set_binop(binop);
- *request->mutable_lhs() = lhs;
- *request->mutable_rhs() = rhs;
- for (int64 dimension : broadcast_dimensions) {
- request->add_broadcast_dimensions(dimension);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::RngOp(
- RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
- const Shape& shape) {
- OpRequest op_request;
- RngRequest* request = op_request.mutable_rng_request();
- request->set_distribution(distribution);
- for (const ComputationDataHandle& param : parameters) {
- *request->add_parameter() = param;
- }
- *request->mutable_shape() = shape;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::TernaryOp(
- TernaryOperation triop, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) {
- OpRequest op_request;
- TernaryOpRequest* request = op_request.mutable_ternary_op_request();
- request->set_triop(triop);
- *request->mutable_lhs() = lhs;
- *request->mutable_rhs() = rhs;
- *request->mutable_ehs() = ehs;
- return RunOpAndParseResponse(&op_request);
-}
-
-Status ComputationBuilder::SetReturnValue(
- const ComputationDataHandle& operand) {
- TF_RETURN_IF_ERROR(first_error_);
-
- SetReturnValueRequest request;
- *request.mutable_computation() = computation_.handle();
- *request.mutable_operand() = operand;
-
- SetReturnValueResponse response;
-
- VLOG(2) << "making set-handle-to-execute request";
- Status s = client_->stub()->SetReturnValue(&request, &response);
- VLOG(2) << "done with request";
-
- if (!s.ok()) {
- NoteError(s);
- return first_error_;
- }
-
- return Status::OK();
-}
-
-StatusOr<bool> ComputationBuilder::IsConstant(
- const ComputationDataHandle& operand, int64 num_parameters) {
- TF_RETURN_IF_ERROR(first_error_);
-
- IsConstantRequest request;
- *request.mutable_computation() = computation_.handle();
- *request.mutable_operand() = operand;
- request.set_num_parameters(num_parameters);
- IsConstantResponse response;
-
- VLOG(2) << "making IsConstant request";
- Status s = client_->stub()->IsConstant(&request, &response);
- VLOG(2) << "done with request";
-
- if (!s.ok()) {
- return s;
- }
- return response.is_constant();
-}
-
-StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
- const ComputationDataHandle& operand, const Layout* output_layout,
- tensorflow::gtl::ArraySlice<Literal> parameters) {
- TF_RETURN_IF_ERROR(first_error_);
-
- ComputeConstantRequest request;
- *request.mutable_computation() = computation_.handle();
- *request.mutable_operand() = operand;
- if (output_layout != nullptr) {
- *request.mutable_output_layout() = *output_layout;
- }
- for (const auto& param : parameters) {
- *request.add_parameters() = param.ToProto();
- }
-
- ComputeConstantResponse response;
-
- VLOG(2) << "making compute-constant request";
- Status s = client_->stub()->ComputeConstant(&request, &response);
- VLOG(2) << "done with request";
-
- if (!s.ok()) {
- return s;
- }
-
- VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
-
- if (!response.has_literal()) {
- return InternalError(
- "no computed literal in the provided response in ComputeConstant "
- "request");
- }
- return Literal::CreateFromProto(response.literal());
-}
-
-ComputationDataHandle ComputationBuilder::Map(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
- OpRequest op_request;
- MapRequest* request = op_request.mutable_map_request();
- for (const ComputationDataHandle& operand : operands) {
- *request->add_operands() = operand;
- }
- *request->mutable_to_apply() = computation.handle();
- for (int64 dimension : dimensions) {
- request->add_dimensions(dimension);
- }
- for (const ComputationDataHandle& sop : static_operands) {
- *request->add_static_operands() = sop;
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::RngNormal(
- const ComputationDataHandle& mu, const ComputationDataHandle& sigma,
- const Shape& shape) {
- return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
-}
-
-ComputationDataHandle ComputationBuilder::RngUniform(
- const ComputationDataHandle& a, const ComputationDataHandle& b,
- const Shape& shape) {
- return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
-}
-
-ComputationDataHandle ComputationBuilder::While(
- const Computation& condition, const Computation& body,
- const ComputationDataHandle& init) {
- OpRequest op_request;
- WhileRequest* request = op_request.mutable_while_request();
- *request->mutable_condition() = condition.handle();
- *request->mutable_body() = body.handle();
- *request->mutable_init() = init;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Gather(
- const ComputationDataHandle& input,
- const ComputationDataHandle& gather_indices,
- const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds) {
- OpRequest op_request;
- GatherRequest* gather_request = op_request.mutable_gather_request();
- *gather_request->mutable_input() = input;
- *gather_request->mutable_gather_indices() = gather_indices;
- *gather_request->mutable_dimension_numbers() = dimension_numbers;
- for (int64 window_bound : window_bounds) {
- gather_request->add_window_bounds(window_bound);
- }
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Conditional(
- const ComputationDataHandle& predicate,
- const ComputationDataHandle& true_operand,
- const Computation& true_computation,
- const ComputationDataHandle& false_operand,
- const Computation& false_computation) {
- OpRequest op_request;
- ConditionalRequest* request = op_request.mutable_conditional_request();
- *request->mutable_predicate() = predicate;
- *request->mutable_true_operand() = true_operand;
- *request->mutable_true_computation() = true_computation.handle();
- *request->mutable_false_operand() = false_operand;
- *request->mutable_false_computation() = false_computation.handle();
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Reduce(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
- OpRequest op_request;
- ReduceRequest* request = op_request.mutable_reduce_request();
- *request->mutable_operand() = operand;
- *request->mutable_init_value() = init_value;
- for (int64 dimension : dimensions_to_reduce) {
- request->add_dimensions(dimension);
- }
- *request->mutable_to_apply() = computation.handle();
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::ReduceAll(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation) {
- if (!first_error_.ok() || !PrepareComputation().ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
- if (!shape.ok()) {
- return ComputationDataHandle();
- }
-
- std::vector<int64> all_dimnos(ShapeUtil::Rank(*shape.ValueOrDie()));
- std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
- return Reduce(operand, init_value, computation, all_dimnos);
-}
-
-ComputationDataHandle ComputationBuilder::ReduceWindow(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
- if (!first_error_.ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
- if (!shape.ok()) {
- return ComputationDataHandle();
- }
-
- Status padding_valid =
- ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()),
- window_dimensions, window_strides);
- if (!padding_valid.ok()) {
- first_error_ = padding_valid;
- return ComputationDataHandle();
- }
-
- std::vector<std::pair<int64, int64>> padding_values =
- MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
- window_dimensions, window_strides, padding);
- return ReduceWindowWithGeneralPadding(operand, init_value, computation,
- window_dimensions, window_strides,
- padding_values);
-}
-
-ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
- OpRequest op_request;
- ReduceWindowRequest* request = op_request.mutable_reduce_window_request();
- *request->mutable_operand() = operand;
- *request->mutable_to_apply() = computation.handle();
- *request->mutable_init_value() = init_value;
-
- if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
- request->mutable_window())) {
- NoteError(InternalError("failed to make window"));
- return ComputationDataHandle();
- }
-
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::BatchNormTraining(
- const ComputationDataHandle& operand, const ComputationDataHandle& scale,
- const ComputationDataHandle& offset, float epsilon, int64 feature_index) {
- OpRequest op_request;
- BatchNormTrainingRequest* request =
- op_request.mutable_batch_norm_training_request();
- *request->mutable_operand() = operand;
- *request->mutable_scale() = scale;
- *request->mutable_offset() = offset;
- request->set_epsilon(epsilon);
- request->set_feature_index(feature_index);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::BatchNormInference(
- const ComputationDataHandle& operand, const ComputationDataHandle& scale,
- const ComputationDataHandle& offset, const ComputationDataHandle& mean,
- const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
- OpRequest op_request;
- BatchNormInferenceRequest* request =
- op_request.mutable_batch_norm_inference_request();
- *request->mutable_operand() = operand;
- *request->mutable_scale() = scale;
- *request->mutable_offset() = offset;
- *request->mutable_mean() = mean;
- *request->mutable_variance() = variance;
- request->set_epsilon(epsilon);
- request->set_feature_index(feature_index);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::BatchNormGrad(
- const ComputationDataHandle& operand, const ComputationDataHandle& scale,
- const ComputationDataHandle& batch_mean,
- const ComputationDataHandle& batch_var,
- const ComputationDataHandle& grad_output, float epsilon,
- int64 feature_index) {
- OpRequest op_request;
- BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request();
- *request->mutable_operand() = operand;
- *request->mutable_scale() = scale;
- *request->mutable_mean() = batch_mean;
- *request->mutable_variance() = batch_var;
- *request->mutable_grad_output() = grad_output;
- request->set_epsilon(epsilon);
- request->set_feature_index(feature_index);
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::CrossReplicaSum(
- const ComputationDataHandle& operand) {
- OpRequest op_request;
- CrossReplicaSumRequest* request =
- op_request.mutable_cross_replica_sum_request();
- *request->mutable_operand() = operand;
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::SelectAndScatter(
- const ComputationDataHandle& operand, const Computation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ComputationDataHandle& source,
- const ComputationDataHandle& init_value, const Computation& scatter) {
- if (!first_error_.ok()) {
- return ComputationDataHandle();
- }
-
- StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
- if (!shape.ok()) {
- return ComputationDataHandle();
- }
- return SelectAndScatterWithGeneralPadding(
- operand, select, window_dimensions, window_strides,
- MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
- window_dimensions, window_strides, padding),
- source, init_value, scatter);
-}
-
-ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
- const ComputationDataHandle& operand, const Computation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ComputationDataHandle& source,
- const ComputationDataHandle& init_value, const Computation& scatter) {
- OpRequest op_request;
- SelectAndScatterRequest* request =
- op_request.mutable_select_and_scatter_request();
- *request->mutable_operand() = operand;
- *request->mutable_select() = select.handle();
- *request->mutable_source() = source;
- *request->mutable_init_value() = init_value;
- *request->mutable_scatter() = scatter.handle();
-
- if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
- request->mutable_window())) {
- NoteError(InternalError("failed to make window"));
- return ComputationDataHandle();
- }
-
- return RunOpAndParseResponse(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::ReducePrecision(
- const ComputationDataHandle& operand, const int exponent_bits,
- const int mantissa_bits) {
- OpRequest op_request;
- ReducePrecisionRequest* request =
- op_request.mutable_reduce_precision_request();
- *request->mutable_operand() = operand;
- request->set_exponent_bits(exponent_bits);
- request->set_mantissa_bits(mantissa_bits);
- return RunOpAndParseResponse(&op_request);
-}
-
-void ComputationBuilder::Send(const ComputationDataHandle& operand,
- const ChannelHandle& handle) {
- OpRequest op_request;
- SendRequest* request = op_request.mutable_send_request();
- *request->mutable_operand() = operand;
- *request->mutable_channel_handle() = handle;
- *op_request.mutable_computation() = computation_.handle();
- RunOpAndNoteError(&op_request);
-}
-
-ComputationDataHandle ComputationBuilder::Recv(const Shape& shape,
- const ChannelHandle& handle) {
- OpRequest op_request;
- RecvRequest* request = op_request.mutable_recv_request();
- *request->mutable_shape() = shape;
- *request->mutable_channel_handle() = handle;
- return RunOpAndParseResponse(&op_request);
-}
-
-Computation ComputationBuilder::BuildAndNoteError() {
- DCHECK(parent_builder_ != nullptr);
- auto build_status = Build();
- if (!build_status.ok()) {
- parent_builder_->NoteError(
- AddStatus(build_status.status(),
- tensorflow::strings::StrCat("error from: ", name_)));
- return Computation();
- }
- return build_status.ConsumeValueOrDie();
-}
-
-StatusOr<Computation> ComputationBuilder::Build() {
- if (!first_error_.ok()) {
- string backtrace;
- first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
- return AppendStatus(first_error_, backtrace);
- }
-
- if (computation_.IsNull()) {
- return FailedPrecondition("no computation was built");
- }
-
- return {std::move(computation_)};
-}
-
-/* static */ ConvolutionDimensionNumbers
-ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
- ConvolutionDimensionNumbers dimension_numbers;
- dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
- dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
- dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
- dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
- dimension_numbers.set_kernel_output_feature_dimension(
- kConvKernelOutputDimension);
- dimension_numbers.set_kernel_input_feature_dimension(
- kConvKernelInputDimension);
- for (int i = 0; i < num_spatial_dims; ++i) {
- dimension_numbers.add_input_spatial_dimensions(i + 2);
- dimension_numbers.add_kernel_spatial_dimensions(i + 2);
- dimension_numbers.add_output_spatial_dimensions(i + 2);
- }
- return dimension_numbers;
-}
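Assuming the kConv* constants defined earlier in this file carry their usual values (batch = 0, feature = 1, kernel output feature = 0, kernel input feature = 1), the defaults for two spatial dimensions amount to:

// Sketch of the resulting numbering for num_spatial_dims = 2 (values assumed):
//   input/output : batch = 0, feature = 1, spatial = {2, 3}              (NCHW)
//   kernel       : output_feature = 0, input_feature = 1, spatial = {2, 3}  (OIHW)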
-
-/* static */ StatusOr<ConvolutionDimensionNumbers>
-ComputationBuilder::CreateConvDimensionNumbers(
- int64 input_batch, int64 input_feature, int64 input_first_spatial,
- int64 input_second_spatial, int64 output_batch, int64 output_feature,
- int64 output_first_spatial, int64 output_second_spatial,
- int64 kernel_output_feature, int64 kernel_input_feature,
- int64 kernel_first_spatial, int64 kernel_second_spatial) {
- if (std::set<int64>({input_batch, input_feature, input_first_spatial,
- input_second_spatial})
- .size() != 4) {
- return FailedPrecondition(
- "dimension numbers for the input are not unique: (%lld, %lld, %lld, "
- "%lld)",
- input_batch, input_feature, input_first_spatial, input_second_spatial);
- }
- if (std::set<int64>({kernel_output_feature, kernel_input_feature,
- kernel_first_spatial, kernel_second_spatial})
- .size() != 4) {
- return FailedPrecondition(
- "dimension numbers for the weight are not unique: (%lld, %lld, %lld, "
- "%lld)",
- kernel_output_feature, kernel_input_feature, kernel_first_spatial,
- kernel_second_spatial);
- }
- if (std::set<int64>({output_batch, output_feature, output_first_spatial,
- output_second_spatial})
- .size() != 4) {
- return FailedPrecondition(
- "dimension numbers for the output are not unique: (%lld, %lld, %lld, "
- "%lld)",
- output_batch, output_feature, output_first_spatial,
- output_second_spatial);
- }
- ConvolutionDimensionNumbers dimension_numbers;
- dimension_numbers.set_input_batch_dimension(input_batch);
- dimension_numbers.set_input_feature_dimension(input_feature);
- dimension_numbers.add_input_spatial_dimensions(input_first_spatial);
- dimension_numbers.add_input_spatial_dimensions(input_second_spatial);
- dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature);
- dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature);
- dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial);
- dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial);
- dimension_numbers.set_output_batch_dimension(output_batch);
- dimension_numbers.set_output_feature_dimension(output_feature);
- dimension_numbers.add_output_spatial_dimensions(output_first_spatial);
- dimension_numbers.add_output_spatial_dimensions(output_second_spatial);
- return dimension_numbers;
-}
-
-} // namespace xla
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
deleted file mode 100644
index ac1eb915cc..0000000000
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ /dev/null
@@ -1,1067 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
-#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
-
-#include <functional>
-#include <initializer_list>
-#include <memory>
-#include <string>
-#include <utility>
-
-#include "tensorflow/compiler/xla/array.h"
-#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/array3d.h"
-#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/client.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/global_data.h"
-#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/bitmap.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/stacktrace.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace xla {
-
-// Wraps an XLA client with a convenient interface for building up
-// computations. Any errors encountered in building up the computation are
-// deferred from being handled until Build() is called.
-//
-// Thread-compatible.
-//
-// TODO(b/74197823): Deprecated. Use XlaBuilder instead.
-class ComputationBuilder {
- public:
- // client: client in which to build the computation.
- // computation_name: name to use for the built computation.
- ComputationBuilder(Client* client, const string& computation_name);
-
- ~ComputationBuilder();
-
- // Returns the client the builder was initialized with.
- Client* client() const { return client_; }
-
- // Returns the computation name.
- const string& name() const { return name_; }
-
- // Sets OpMetadata that will be added to all instructions until cleared.
- //
- // OpMetadata is often applied to a series of XLA HLO instructions. As a
- // result, OpMetadata is set on the Computation Builder. All subsequent
- // instructions generated via this Computation Builder will have the same
- // OpMetadata attached until a call to ClearOpMetadata.
- void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
-
- // Clears the OpMetadata state.
- void ClearOpMetadata() { metadata_.Clear(); }
-
- // Sets an OpSharding that will be attached to all instructions until cleared.
- void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
-
- // Clears the sharding. Ops will be sharded according to the default placement
- // policy.
- void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; }
-
- // Returns the OpSharding that will be attached to all instructions.
- const tensorflow::gtl::optional<OpSharding>& sharding() const {
- return sharding_;
- }
-
- // Sets the builder to a mode where it will die immediately when an error is
- // encountered, rather than producing it in a deferred fashion when Build() is
- // called (which is the default).
- void set_die_immediately_on_error(bool enabled) {
- die_immediately_on_error_ = enabled;
- }
-
- // Enqueues a "retrieve parameter value" instruction for a parameter that was
- // passed to the computation.
- ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape,
- const string& name);
-
- // Retrieves the (inferred) shape of the operand in the computation.
- StatusOr<std::unique_ptr<Shape>> GetShape(
- const ComputationDataHandle& operand);
-
- // Retrieves the (inferred) shape of the current computation's result.
- StatusOr<ProgramShape> GetProgramShape();
-
- // Enqueues a constant with the value of the given literal onto the
- // computation.
- ComputationDataHandle ConstantLiteral(const Literal& literal);
-
- // Enqueues a constant onto the computation. Methods are templated on the
- // native host type (NativeT) which corresponds to a specific XLA
- // PrimitiveType as given in the following table:
- //
- // Native Type PrimitiveType
- // -----------------------------
- // bool PRED
- // int32 S32
- // int64 S64
- // uint32 U32
- // uint64 U64
- // float F32
- // double F64
- //
- // Note: not all primitive types defined in xla_data.proto have a
- // corresponding native type yet.
- template <typename NativeT>
- ComputationDataHandle ConstantR0(NativeT value);
- template <typename NativeT>
- ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values);
- ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values);
- template <typename NativeT>
- ComputationDataHandle ConstantR2(
- std::initializer_list<std::initializer_list<NativeT>> values);
- template <typename NativeT>
- ComputationDataHandle ConstantFromArrayWithLayout(
- const Array<NativeT>& values, const Layout& layout);
- template <typename NativeT>
- ComputationDataHandle ConstantFromArray(const Array<NativeT>& values);
- template <typename NativeT>
- ComputationDataHandle ConstantR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout);
- template <typename NativeT>
- ComputationDataHandle ConstantR2FromArray2D(const Array2D<NativeT>& values);
- template <typename NativeT>
- ComputationDataHandle ConstantR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout);
- template <typename NativeT>
- ComputationDataHandle ConstantR3FromArray3D(const Array3D<NativeT>& values);
- template <typename NativeT>
- ComputationDataHandle ConstantR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout);
- template <typename NativeT>
- ComputationDataHandle ConstantR4FromArray4D(const Array4D<NativeT>& values);
-
- // Enqueues a rank one constant (vector) onto the computation. The vector has
- // size 'length' and every element has the value 'value'.
- template <typename NativeT>
- ComputationDataHandle ConstantR1(int64 length, NativeT value);
-
- // Adds dimensions to an array by duplicating the data in the array.
- //
- // The new dimensions are inserted on the left, i.e. if
- // broadcast_sizes has values {a0, ..., aN} and the operand shape
- // has dimensions {b0, ..., bM} then the shape of the output has
- // dimensions {a0, ..., aN, b0, ..., bM}.
- //
- // The new dimensions index into copies of the operand, i.e.
- //
- // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
- ComputationDataHandle Broadcast(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
-
- // Enqueues a pad operation onto the computation that pads the given value on
- // the edges as well as between the elements of the input. padding_config
- // specifies the padding amount for each dimension.
- ComputationDataHandle Pad(const ComputationDataHandle& operand,
- const ComputationDataHandle& padding_value,
- const PaddingConfig& padding_config);
-
- // Enqueues an operation onto the computation that flattens the operand based
- // on the dimension order (major/slowest-varying to minor/fastest-varying)
- // given, followed by reshaping it into the shape with the given dimension
- // sizes (also major to minor). Conceptually, this is a limited form of
- // "shape casting".
- ComputationDataHandle Reshape(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
-
- // Enqueues an operation onto the computation that collapses the operand, from
- // first to last dimension (C order), then reshapes it to the given dimension
- // sizes. Conceptually, this is a limited form of "shape casting".
- ComputationDataHandle Reshape(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> new_sizes);
-
- // Wrapper for Reshape.
- // Enqueues an operation to collapse the provided dimensions; e.g. an
- // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
- // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
- // be a consecutive, in-order subsequence of the operand dimensions.
- //
- // Note that collapsing a single dimension does nothing:
- //
- // {256} collapsing {0} => {256}
- // {1} collapsing {0} => {1}
- //
- // Collapsing multiple dimensions produces a single result dimension:
- //
- // {256, 2} collapsing {0,1} => {512}
- // {256, 2, 3} collapsing {0,1} => {512, 3}
- //
- // This could potentially cause data to be moved -- it provides a more
- // structured form of reshaping than an arbitrary Reshape operation.
- ComputationDataHandle Collapse(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
- // Enqueues a slice operation onto the computation that slices the operand
- // from the start indices to the limit indices; e.g.
- //
- // x
- // [ 0 1 2 3 ]
- // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
- // [ 8 9 a b ]
- //
- // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
- // range notation.
- // The strides parameter determines the stride over the slice.
- ComputationDataHandle Slice(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
-
- // Enqueues a slice operation in a given dimension, taking all other
- // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
- // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
- // for:
- //
- // array[:, 2:4:1, :]
- ComputationDataHandle SliceInDim(const ComputationDataHandle& operand,
- int64 start_index, int64 limit_index,
- int64 stride, int64 dimno);
-
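A sketch of both slicing forms, reusing the examples from the comments above (handles `x` of shape f32[3,4] and `y` of shape f32[7,8,9] are assumed to have been built earlier):

  // Picks out [[5, 6]] from the 3x4 matrix in the comment above.
  auto s1 = b.Slice(x, /*start_indices=*/{1, 1}, /*limit_indices=*/{2, 3},
                    /*strides=*/{1, 1});
  // Short-hand for y[:, 2:4:1, :] on an f32[7,8,9] operand.
  auto s2 = b.SliceInDim(y, /*start_index=*/2, /*limit_index=*/4,
                         /*stride=*/1, /*dimno=*/1);
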
- // Enqueues a slice operation onto the computation that slices the 'operand'
- // from dynamic start indices which are passed in 'start_indices'.
- // The size of the slice in each dimension is passed in 'slice_sizes',
- // which specify the end point of exclusive slice intervals in each
- // dimension [start, start + size).
- // The shape of 'start_indices' must be rank == 1, with dimension size
- // equal to the rank of the 'operand'.
- // Slice index calculations are computed modulo input dimension sizes to
- // prevent dynamic start indices from generating out-of-bound array accesses.
- ComputationDataHandle DynamicSlice(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& start_indices,
- tensorflow::gtl::ArraySlice<int64> slice_sizes);
-
- // Enqueues a dynamic update slice operation onto the computation, which
- // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
- // The shape of 'update' determines the shape of the slice of 'operand'
- // which is updated.
- // The indices specified in 'start_indices' specify the offset of the slice
- // of 'operand' which is updated.
- //
-  //    update = {10, 11} // calculated at runtime.
-  //    [1 2 3]     start  = {1, 1}  // calculated at runtime.   [1 2 3 ]
-  //    [4 5 6]  => DynamicUpdateSlice(data, update, start)  =>  [4 10 11]
-  //    [7 8 9]                                                  [7 8 9 ]
- //
- // The shape of 'start_indices' must be rank == 1, with dimension size
- // equal to the rank of the 'operand'.
- // Slice index calculations are computed modulo update dimension sizes to
- // prevent dynamic start indices from generating out-of-bound array accesses.
- ComputationDataHandle DynamicUpdateSlice(
- const ComputationDataHandle& operand, const ComputationDataHandle& update,
- const ComputationDataHandle& start_indices);
-
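The runtime-indexed variants described in the two comments above could be exercised roughly as follows (a sketch; in practice the start indices would usually be computed rather than constant):

  auto data   = b.ConstantR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  auto start  = b.ConstantR1<int32>({1, 1});             // rank 1, size == rank(data)
  auto window = b.DynamicSlice(data, start, /*slice_sizes=*/{1, 2});  // [[5, 6]]
  auto update = b.ConstantR2<int32>({{10, 11}});
  auto result = b.DynamicUpdateSlice(data, update, start);
  // result: [[1, 2, 3], [4, 10, 11], [7, 8, 9]]
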
- // Enqueues a concatenate instruction onto the computation. 'operands' must
- // have >= 1 entry.
- ComputationDataHandle ConcatInDim(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- int64 dimension);
-
- // Enqueue a tracing operation onto the computation; the computation will emit
- // a logging message with the operand.
- void Trace(const string& tag, const ComputationDataHandle& operand);
-
- // Enqueues a conditional-move-like select operation onto the computation;
- // predicated on pred, selects between on_true and on_false.
- ComputationDataHandle Select(const ComputationDataHandle& pred,
- const ComputationDataHandle& on_true,
- const ComputationDataHandle& on_false);
-
- // Enqueues a tuple-creation instruction onto the computation.
- ComputationDataHandle Tuple(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);
-
- // Enqueues a tuple-element-get instruction onto the computation.
- ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data,
- int64 index);
-
- // Enqueues an equal-to comparison instruction onto the computation.
- ComputationDataHandle Eq(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a not-equal comparison instruction onto the computation.
- ComputationDataHandle Ne(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a greater-or-equal comparison instruction onto the computation.
- ComputationDataHandle Ge(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a greater-than comparison instruction onto the computation.
- ComputationDataHandle Gt(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a less-than comparison instruction onto the computation.
- ComputationDataHandle Lt(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a less-or-equal comparison instruction onto the computation.
- ComputationDataHandle Le(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a dot instruction onto the computation.
- ComputationDataHandle Dot(const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs);
-
- // Enqueues a general dot instruction onto the computation.
- ComputationDataHandle DotGeneral(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- const DotDimensionNumbers& dimension_numbers);
-
- // Default dimension numbers used for a 2D convolution.
- static constexpr int64 kConvBatchDimension = 0;
- static constexpr int64 kConvFeatureDimension = 1;
- static constexpr int64 kConvFirstSpatialDimension = 2;
- static constexpr int64 kConvSecondSpatialDimension = 3;
- static constexpr int64 kConvKernelOutputDimension = 0;
- static constexpr int64 kConvKernelInputDimension = 1;
- static constexpr int64 kConvKernelFirstSpatialDimension = 2;
- static constexpr int64 kConvKernelSecondSpatialDimension = 3;
-
- // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
- // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
- // the kernel operand
- // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
- static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
- int num_spatial_dims = 2);
-
- // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an
- // error if either the input or the weight dimension numbers have conflicts.
- static StatusOr<ConvolutionDimensionNumbers> CreateConvDimensionNumbers(
- int64 input_batch, int64 input_feature, int64 input_first_spatial,
- int64 input_second_spatial, int64 output_batch, int64 output_feature,
- int64 output_first_spatial, int64 output_second_spatial,
- int64 kernel_output_feature, int64 kernel_input_feature,
- int64 kernel_first_spatial, int64 kernel_second_spatial);
-
- // Enqueues a convolution instruction onto the computation, which uses the
- // default convolution dimension numbers.
- ComputationDataHandle Conv(const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
-
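A rough sketch of a 2D convolution under the default dimension numbers listed above (`lhs` is an assumed f32[batch, features, height, width] activation handle and `rhs` an assumed f32[out_features, in_features, kh, kw] kernel handle):

  auto out = b.Conv(lhs, rhs, /*window_strides=*/{1, 1}, Padding::kValid);
  // Equivalent call with the default dimension numbers spelled out:
  auto out2 = b.ConvWithGeneralDimensions(
      lhs, rhs, /*window_strides=*/{1, 1}, Padding::kValid,
      ComputationBuilder::CreateDefaultConvDimensionNumbers(/*num_spatial_dims=*/2));
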
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided padding configuration in the format returned by MakePadding().
- ComputationDataHandle ConvWithGeneralPadding(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
-
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided dimension numbers configuration.
- ComputationDataHandle ConvWithGeneralDimensions(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided padding configuration as well as the dimension numbers.
- ComputationDataHandle ConvGeneral(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
- // Enqueues a convolution instruction onto the computation, with the caller
- // provided padding configuration, dilation factors and dimension numbers.
- ComputationDataHandle ConvGeneralDilated(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
- // Enqueues an FFT instruction onto the computation, of the given type and
- // with the given FFT length.
- ComputationDataHandle Fft(const ComputationDataHandle& operand,
- FftType fft_type,
- tensorflow::gtl::ArraySlice<int64> fft_length);
-
- // Enqueues an infeed instruction onto the computation, which writes data of
- // the given shape to the infeed buffer of the device.
- ComputationDataHandle Infeed(const Shape& shape, const string& config = "");
-
- // Enqueues an outfeed instruction onto the computation. This instruction
- // generates outgoing data transfers for the given data.
- //
- // shape_with_layout communicates the laid out shape that we want to outfeed
- // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
- // will occur.
- void Outfeed(const ComputationDataHandle& operand,
- const Shape& shape_with_layout, const string& outfeed_config);
-
- // Enqueues a call instruction onto the computation.
- ComputationDataHandle Call(
- const Computation& computation,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands);
-
- // Enqueues a custom call instruction onto the computation.
- // During code generation, a call instruction is emitted which targets a
- // symbol with the name |call_target_name|. The |operands| are passed to the
- // call instruction. |shape| is the resultant shape.
- ComputationDataHandle CustomCall(
- const string& call_target_name,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const Shape& shape);
-
- // Enqueues a pseudo-op to represent host-side computation data-dependencies.
- // During code generation, host send and receive operations will be generated
- // to transfer |operands| to the host and a single result of |shape| back to
- // the device. Host send/recv operations are emitted using |channel_name|.
- // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
- // instruction scheduling.
- ComputationDataHandle HostCompute(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const string& channel_name, int64 cost_estimate_ns, const Shape& shape);
-
- // The following methods enqueue element-wise binary arithmetic operations
- // onto the computation. The shapes of the operands have to match unless one
- // of the operands is a scalar, or an explicit broadcast dimension is given
- // (see g3doc for more details).
-
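For example, adding a rank-1 operand to a rank-2 operand with an explicit broadcast dimension might look like this (a sketch; `m` of shape f32[2,3] and `v` of shape f32[3] are assumed):

  // broadcast_dimensions maps dimension 0 of v onto dimension 1 of m, so v is
  // added to every row of m.
  auto sum = b.Add(m, v, /*broadcast_dimensions=*/{1});
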
- // Enqueues a complex compose instruction onto the computation.
- ComputationDataHandle Complex(
- const ComputationDataHandle& real, const ComputationDataHandle& imag,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a complex conjugate instruction onto the computation.
- ComputationDataHandle Conj(const ComputationDataHandle& operand);
-
- // Enqueues an add instruction onto the computation.
- ComputationDataHandle Add(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a subtract instruction onto the computation.
- ComputationDataHandle Sub(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a multiply instruction onto the computation.
- ComputationDataHandle Mul(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a divide instruction onto the computation.
- ComputationDataHandle Div(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a remainder instruction onto the computation.
- ComputationDataHandle Rem(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a max instruction onto the computation.
- ComputationDataHandle Max(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues a min instruction onto the computation.
- ComputationDataHandle Min(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Element-wise logical operators
- ComputationDataHandle And(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- ComputationDataHandle Or(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- ComputationDataHandle Xor(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- ComputationDataHandle Not(const ComputationDataHandle& operand);
-
- ComputationDataHandle ShiftLeft(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
- ComputationDataHandle ShiftRightArithmetic(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
- ComputationDataHandle ShiftRightLogical(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-  // Reduces an array along the provided dimensions, using "computation" as the
-  // reduction operator.
- ComputationDataHandle Reduce(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
-
- // Convenience wrapper around the above that reduces all the dimensions in the
- // operand shape.
- ComputationDataHandle ReduceAll(const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value,
- const Computation& computation);
-
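A minimal sketch of a sum reduction, showing how the reduction operator is itself a Computation built with a sub-builder (names are illustrative; `operand` is assumed to be an f32[2,3] handle):

  std::unique_ptr<ComputationBuilder> add_builder = b.CreateSubBuilder("add");
  auto x = add_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
  auto y = add_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
  add_builder->Add(x, y);
  Computation add = add_builder->BuildAndNoteError();

  auto zero = b.ConstantR0<float>(0.0f);
  auto row_sums = b.Reduce(operand, zero, add, /*dimensions_to_reduce=*/{0});  // f32[3]
  // ReduceAll(operand, zero, add) would instead reduce every dimension to a scalar.
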
- // Enqueues a windowed reduce instruction onto the computation.
- ComputationDataHandle ReduceWindow(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
-
- // As ReduceWindow(), but the padding is given in the format
- // returned by MakePadding().
- ComputationDataHandle ReduceWindowWithGeneralPadding(
- const ComputationDataHandle& operand,
- const ComputationDataHandle& init_value, const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
-
- // Returns the sum of the operand value across all replicas. All replicas
- // supply one input to the sum and all replicas receive the resulting sum.
- ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand);
-
- // Enqueues an operation that scatters the `source` array to the selected
- // indices of each window.
- ComputationDataHandle SelectAndScatter(
- const ComputationDataHandle& operand, const Computation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ComputationDataHandle& source,
- const ComputationDataHandle& init_value, const Computation& scatter);
-
- // As SelectAndScatter(), but the padding is given in the format
- // returned by MakePadding().
- ComputationDataHandle SelectAndScatterWithGeneralPadding(
- const ComputationDataHandle& operand, const Computation& select,
- tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ComputationDataHandle& source,
- const ComputationDataHandle& init_value, const Computation& scatter);
-
- // Enqueues an abs instruction onto the computation.
- ComputationDataHandle Abs(const ComputationDataHandle& operand);
-
-  // Enqueues an atan2 instruction onto the computation.
- ComputationDataHandle Atan2(
- const ComputationDataHandle& y, const ComputationDataHandle& x,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
- // Enqueues an exp instruction onto the computation.
- ComputationDataHandle Exp(const ComputationDataHandle& operand);
-
- // Enqueues a floor instruction onto the computation.
- ComputationDataHandle Floor(const ComputationDataHandle& operand);
-
- // Enqueues a ceil instruction onto the computation.
- ComputationDataHandle Ceil(const ComputationDataHandle& operand);
-
-  // Enqueues a round instruction onto the computation, rounding to the nearest
-  // integer, with half-way cases rounding away from zero.
- ComputationDataHandle Round(const ComputationDataHandle& operand);
-
-  // Enqueues a log instruction (natural logarithm) onto the computation.
- ComputationDataHandle Log(const ComputationDataHandle& operand);
-
- // Enqueues a sign instruction onto the computation.
- ComputationDataHandle Sign(const ComputationDataHandle& operand);
-
- // Enqueues a cosine instruction onto the computation.
- ComputationDataHandle Cos(const ComputationDataHandle& operand);
-
- // Enqueues a sine instruction onto the computation.
- ComputationDataHandle Sin(const ComputationDataHandle& operand);
-
- // Enqueues a tanh instruction onto the computation.
- ComputationDataHandle Tanh(const ComputationDataHandle& operand);
-
- // Enqueues a real-part instruction onto the computation.
- ComputationDataHandle Real(const ComputationDataHandle& operand);
-
- // Enqueues an imaginary-part instruction onto the computation.
- ComputationDataHandle Imag(const ComputationDataHandle& operand);
-
- // Enqueues a float32 sqrt instruction onto the computation.
- // (float32 is specified as there is an implicit float32 0.5f constant
- // exponent).
- ComputationDataHandle SqrtF32(const ComputationDataHandle& operand);
-
- // Enqueues a float32 square instruction onto the computation.
- // (float32 is specified as there is an implicit float32 2.0f constant
- // exponent).
- ComputationDataHandle SquareF32(const ComputationDataHandle& operand);
-
- // Enqueues a lhs^rhs computation onto the computation.
- ComputationDataHandle Pow(
- const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-  // Enqueues an operator that tests if the operand's values are finite, i.e.,
-  // not Inf or NaN. Defined only for floating-point types. Returns an array of
-  // booleans with the same shape where entries are true iff the corresponding
-  // entry is finite.
- ComputationDataHandle IsFinite(const ComputationDataHandle& operand);
-
- // Enqueues a convert instruction onto the computation that changes the
- // element type of the operand array to primitive_type.
- ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
- PrimitiveType new_element_type);
-
- // Enqueues a no-op instruction onto the computation that changes
- // the element type of the operand array to primitive_type. The
- // bit-widths of the source and destination element types must be
- // identical.
- ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand,
- PrimitiveType new_element_type);
-
- // Enqueues a float32 reciprocal instruction onto the computation.
- // (float32 is specified as there is an implicit float32 -1.0f constant
- // exponent).
- //
- // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
- // shape of the operand.
- ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand);
-
- // Enqueues a negate instruction onto the computation.
- ComputationDataHandle Neg(const ComputationDataHandle& operand);
-
- // Enqueues a count-leading-zeros instruction onto the computation.
- ComputationDataHandle Clz(const ComputationDataHandle& operand);
-
- // Enqueues a transpose instruction onto the computation.
- ComputationDataHandle Transpose(
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> permutation);
-
- // Enqueues a reverse instruction onto the computation. The order of the
- // elements in the given dimensions is reversed (i.e., the element at index i
- // is moved to index dimension_size - 1 - i).
- ComputationDataHandle Rev(const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
-  // Enqueues a sort instruction (in increasing order) onto the computation.
- ComputationDataHandle Sort(const ComputationDataHandle& operand);
-
- // Enqueues a clamp instruction onto the computation.
- ComputationDataHandle Clamp(const ComputationDataHandle& min,
- const ComputationDataHandle& operand,
- const ComputationDataHandle& max);
-
- // Enqueues a map instruction onto the computation.
- ComputationDataHandle Map(
- tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
- const Computation& computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands = {});
-
- // Enqueues a N(mu, sigma) random number generation instruction onto the
- // computation.
- ComputationDataHandle RngNormal(const ComputationDataHandle& mu,
- const ComputationDataHandle& sigma,
- const Shape& shape);
-
- // Enqueues a U(a, b) random number generation instruction onto the
- // computation. Returns values in the semi-open interval [a, b).
- ComputationDataHandle RngUniform(const ComputationDataHandle& a,
- const ComputationDataHandle& b,
- const Shape& shape);
-
- // Enqueues a while node onto the computation.
- ComputationDataHandle While(const Computation& condition,
- const Computation& body,
- const ComputationDataHandle& init);
-
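A sketch of the While form: the condition and body are separate Computations over the loop state, here a single counter that is incremented until it reaches 10 (illustrative only):

  Shape state_shape = ShapeUtil::MakeShape(S32, {});

  auto cond_builder = b.CreateSubBuilder("cond");
  auto c = cond_builder->Parameter(0, state_shape, "state");
  cond_builder->Lt(c, cond_builder->ConstantR0<int32>(10));
  Computation cond = cond_builder->BuildAndNoteError();

  auto body_builder = b.CreateSubBuilder("body");
  auto s = body_builder->Parameter(0, state_shape, "state");
  body_builder->Add(s, body_builder->ConstantR0<int32>(1));
  Computation body = body_builder->BuildAndNoteError();

  auto final_state = b.While(cond, body, b.ConstantR0<int32>(0));  // == 10
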
- // Enqueues a conditional node onto the computation.
- ComputationDataHandle Conditional(const ComputationDataHandle& predicate,
- const ComputationDataHandle& true_operand,
- const Computation& true_computation,
- const ComputationDataHandle& false_operand,
- const Computation& false_computation);
-
- // Enqueues a ReducePrecision node onto the computation.
- ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand,
- const int exponent_bits,
- const int mantissa_bits);
-
- // Enqueues a Gather node onto the computation.
- ComputationDataHandle Gather(
- const ComputationDataHandle& input,
- const ComputationDataHandle& gather_indices,
- const GatherDimensionNumbers& dimension_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds);
-
- // Enqueues a Send node onto the computation, to send the given operand to
- // a Recv instruction that shares the same channel handle.
- void Send(const ComputationDataHandle& operand, const ChannelHandle& handle);
-
- // Enqueues a Recv node onto the computation. The data comes from a Send
- // instruction that shares the same channel handle and its shape must
- // be the same as the given shape.
- ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle);
-
- // Returns true if 'operand' is a compile-time constant. A compile-time
- // constant does not depend on parameters with index greater than or equal to
- // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`.
- // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a
- // compile-time constant without evaluating the computation.
- StatusOr<bool> IsConstant(const ComputationDataHandle& operand,
- int64 num_parameters = 0);
-
- // Normalizes operand across spatial and batch dimensions for each feature.
- //
- // Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
- // is the normalized result and batch_mean and batch_var are the mean and
- // variance, respectively, across batch for the operand.
- ComputationDataHandle BatchNormTraining(const ComputationDataHandle& operand,
- const ComputationDataHandle& scale,
- const ComputationDataHandle& offset,
- float epsilon, int64 feature_index);
-
- // Normalizes operand across spatial and batch dimensions for each feature.
- //
- // `BatchNormInference` is equivalent to calling `BatchNormTraining` without
-  // computing `mean` and `variance` for each batch inside the operation.
-  // Instead, it uses the input `mean` and `variance` as estimated values. The
- // purpose of this op is to reduce latency in inference, hence the name
- // `BatchNormInference`.
- //
- // The output has the same shape as `operand`, and contains the normalized
- // values for each batch.
- ComputationDataHandle BatchNormInference(
- const ComputationDataHandle& operand, const ComputationDataHandle& scale,
- const ComputationDataHandle& offset, const ComputationDataHandle& mean,
- const ComputationDataHandle& variance, float epsilon,
- int64 feature_index);
-
- // Calculates the gradients of a batch norm op.
- //
- // The inputs `batch_mean` and `batch_var` represent the mean and variance
- // across the batch.
- //
- // Returns a tuple of three elements:
- // - grad_operand: Gradient with respect to input `operand`
- // - grad_offset: Gradient with respect to input `offset`
- // - grad_scale: Gradient with respect to input `scale`
- ComputationDataHandle BatchNormGrad(const ComputationDataHandle& operand,
- const ComputationDataHandle& scale,
- const ComputationDataHandle& batch_mean,
- const ComputationDataHandle& batch_var,
- const ComputationDataHandle& grad_output,
- float epsilon, int64 feature_index);
-
- // Computes the value of a constant indicated by a
- // ComputationDataHandle using a non-optimized interpreter on the host.
- //
- // The operand must be from the computation currently being built -
- // i.e., returned from this builder with no intervening call to
- // Build(). This happens to currently work regardless of that, but
- // that may stop working at any time.
- //
- // The operand must represent a constant value, which in this case
- // means that it must not statically depend on any parameter of the
-  // computation that is being built other than the ones specified on the
- // parameter list. The parameters in the list will be indexed by their
- // parameter id property so the number of parameters specified should be at
- // least as many as the largest used parameter index.
- //
-  // `IsConstant` can be used to test whether a computation is a compile-time
-  // constant without evaluating it. `ComputeConstant` only succeeds for
- // computations where `IsConstant` returns true.
- //
- // This functionality can be useful when translating a computation
- // into XLA where something that looked dynamic is required by
- // XLA to be specified as a constant. E.g. the source
- // computation (outside of XLA) may include a dynamic
- // computation of the shape of something and ComputeConstant lets
- // you determine what the value of that computation is in the case
- // where the value can be determined at compile time.
- //
- // If output_layout is non-null, then the output of the computation
- // will be stored using that layout.
- StatusOr<std::unique_ptr<Literal>> ComputeConstant(
- const ComputationDataHandle& operand,
- const Layout* output_layout = nullptr,
- tensorflow::gtl::ArraySlice<Literal> parameters = {});
-
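A sketch of the intended usage, checking constancy before folding a value at compile time (it assumes the enclosing function returns a Status so that the TF_ASSIGN_OR_RETURN macros apply):

  auto dim = b.Mul(b.ConstantR0<int32>(4), b.ConstantR0<int32>(8));
  TF_ASSIGN_OR_RETURN(bool is_const, b.IsConstant(dim));
  if (is_const) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> value, b.ComputeConstant(dim));
    // `value` now holds the scalar 32, known without executing the computation.
  }
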
- // Returns a new ComputationBuilder whose resultant Computation is used only
- // by this ComputationBuilder. The sub-ComputationBuilder has the same
- // die_immediately_on_error behavior as the parent.
- std::unique_ptr<ComputationBuilder> CreateSubBuilder(
- const string& computation_name);
-
- // Modifies the computation being built so that executions of it
- // will return the value associated with operand, rather than the
- // last expression enqueued on the ComputationBuilder. Any subsequent
- // operations added to the ComputationBuilder will not have any effect unless
- // SetReturnValue is called again.
- Status SetReturnValue(const ComputationDataHandle& operand);
-
- // Builds the computation with the requested operations, or returns a non-ok
- // status.
- StatusOr<Computation> Build();
-
- // Builds the computation with the requested operations, or notes an error in
- // the parent ComputationBuilder and returns an empty computation if building
- // failed. This function is intended to be used where the returned
-  // Computation is only used by the parent ComputationBuilder, and hence further
-  // operations on the returned Computation will simply produce errors if an
-  // error occurred while building this computation. If the built computation is
- // to be used by a ComputationBuilder other than the parent ComputationBuilder
- // then Build() should be used instead.
- Computation BuildAndNoteError();
-
- // Returns the first error that was encountered while building the
- // computation. When an error is encountered, by default we return a vacuous
- // ComputationDataHandle and inform the user of the error that occurred while
- // building the computation when they make a final call to Build().
- //
- // See also set_die_immediately_on_error().
- Status first_error() const { return first_error_; }
-
- private:
- // Limited checking of convolution parameters. Returns false on
- // error.
- bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape,
- const ConvolutionDimensionNumbers& dimension_numbers);
-
-  // The parent ComputationBuilder of a sub-ComputationBuilder. The
-  // parent_builder_ will be nullptr if this is not a sub-ComputationBuilder.
- ComputationBuilder* parent_builder_{nullptr};
-
- // Helper function for creating a Window proto from user-supplied
- // data. Returns true if the user-supplied data was valid.
- bool MakeWindow(tensorflow::gtl::ArraySlice<int64> window_dimensions,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- tensorflow::gtl::ArraySlice<int64> lhs_dilation,
- tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- Window* window);
-
- // Internal helper method that does the building for an arbitrary unary op.
- ComputationDataHandle UnaryOp(UnaryOperation unop,
- const ComputationDataHandle& operand);
-
- // Internal helper method that does the building for an arbitrary binary op.
- // broadcast_dimensions specifies which dimensions to use for broadcasting
- // when the operation is between tensors of different ranks.
- ComputationDataHandle BinaryOp(
- BinaryOperation binop, const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
-
- // Internal helper method that does the building for an arbitrary ternary op.
- ComputationDataHandle TernaryOp(TernaryOperation triop,
- const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs,
- const ComputationDataHandle& ehs);
-
- // Internal helper method that does the building for a random number generator
- // of a given distribution with an explicitly specified shape.
- ComputationDataHandle RngOp(
- RandomDistribution distribution,
- tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
- const Shape& shape);
-
- // Populates computation_ with a valid object or returns a failing status.
- // This is used before any given operation is enqueued.
- Status PrepareComputation();
-
- // Notes that the error occurred by:
- // * storing it internally and capturing a backtrace if it's the first error
- // (this deferred value will be produced on the call to Build())
- // * dying if die_immediately_on_error_ is true
- void NoteError(const Status& error);
-
- // Helper function that runs the given op_request, filling in op_response.
- // Before the op is run, PrepareComputation is called, and common fields in
- // the op_request are filled in.
- Status RunOp(OpRequest* op_request, OpResponse* op_response);
-
- // Helper function that calls RunOp and calls NoteError on failures.
- void RunOpAndNoteError(OpRequest* op_request);
-
- // Helper function that calls RunOp and either returns the output computation
- // data handle (on success) or a vacuous computation data handle (on failure).
- ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request);
-
- // Helper function that implements GetShape without noting errors. This makes
- // it easier to ensure the real GetShape will note errors on every error path.
- StatusOr<std::unique_ptr<Shape>> GetShapeWithoutNoteError(
- const ComputationDataHandle& operand);
-
- string name_; // Name to use for the built computation.
-
- // The first error encountered while building the computation.
- // This is OK until the first error is encountered.
- Status first_error_;
-
- // The saved stack trace from the point at which the first error occurred.
- tensorflow::SavedStackTrace first_error_backtrace_;
-
- // The computation that operations are enqueued onto.
- Computation computation_;
-
- // The client that the computation is created in. Not owned.
- Client* client_;
-
- // Mode bit that indicates whether to die when a first error is encountered.
- bool die_immediately_on_error_ = false;
-
- // The metadata to attach to each op. This is structured as a "modal"-like
- // operation, in order to simplify client code (and not sprinkle this metadata
- // throughout the TensorFlow op kernel implementations).
- OpMetadata metadata_;
-
-  // Sharding for this operator. This is structured as a "modal"-like operation,
- // in order to simplify client code, similar to metadata_.
- tensorflow::gtl::optional<OpSharding> sharding_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
-};
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR1(
- tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR1(int64 length,
- NativeT value) {
- Literal literal(ShapeUtil::MakeShape(
- primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
- literal.PopulateWithValue(value);
- return ConstantLiteral(literal);
-}
-
-inline ComputationDataHandle ComputationBuilder::ConstantR1(
- const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*Literal::CreateR1(values));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR2(
- std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
- const Array<NativeT>& values, const Layout& layout) {
- return ConstantLiteral(
- *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantFromArray(
- const Array<NativeT>& values) {
- return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout) {
- return ConstantLiteral(
- *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
- const Array2D<NativeT>& values) {
- return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout) {
- return ConstantLiteral(
- *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
- const Array3D<NativeT>& values) {
- return ConstantFromArray(values);
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout) {
- return ConstantFromArrayWithLayout(values, layout);
-}
-
-template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
- const Array4D<NativeT>& values) {
- return ConstantFromArray(values);
-}
-
-// RAII-style object: sets the current sharding assignment in builder on
-// construction, and sets back to the previous assignment on destruction.
-class ScopedShardingAssignment {
- public:
- ScopedShardingAssignment(xla::ComputationBuilder* builder,
- tensorflow::gtl::optional<OpSharding> sharding)
- : builder_(builder), prev_sharding_(builder->sharding()) {
- SetSharding(sharding);
- }
-
- ~ScopedShardingAssignment() { SetSharding(prev_sharding_); }
-
- private:
- void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
- if (sharding.has_value()) {
- builder_->SetSharding(sharding.value());
- } else {
- builder_->ClearSharding();
- }
- }
-
- xla::ComputationBuilder* const builder_;
- tensorflow::gtl::optional<OpSharding> prev_sharding_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment);
-};
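Typical usage is a scoped block, so that only the ops built inside it pick up the sharding (a sketch; `sharding`, `x`, and `y` are assumed to exist):

  {
    ScopedShardingAssignment assign(&b, sharding);
    auto sharded_sum = b.Add(x, y);    // gets `sharding` attached
  }                                     // previous sharding restored here
  auto unsharded_product = b.Mul(x, y);
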
-
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc
index 40f59eaa68..2986d40600 100644
--- a/tensorflow/compiler/xla/client/global_data.cc
+++ b/tensorflow/compiler/xla/client/global_data.cc
@@ -31,7 +31,7 @@ GlobalData::~GlobalData() {
*request.mutable_data() = handle_;
UnregisterResponse response;
VLOG(1) << "requesting to unregister " << handle_.ShortDebugString();
- tensorflow::Status s = parent_->Unregister(&request, &response);
+ Status s = parent_->Unregister(&request, &response);
VLOG(1) << "done with request";
if (!s.ok()) {
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 1acc6f8686..9d44d3ad7d 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -48,7 +48,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
<< "Must have a valid device ordinal that the executable was built for.";
}
-tensorflow::Status LocalExecutable::ValidateExecutionOptions(
+Status LocalExecutable::ValidateExecutionOptions(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableRunOptions& run_options, const Backend& backend) {
const ComputationLayout& host_computation_layout =
@@ -207,7 +207,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
return std::move(result);
}
-tensorflow::Status LocalExecutable::RecordArguments(
+Status LocalExecutable::RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
SessionModule* session_module) {
session_module->clear_arguments();
@@ -219,8 +219,8 @@ tensorflow::Status LocalExecutable::RecordArguments(
return Status::OK();
}
-tensorflow::Status LocalExecutable::RecordResult(
- const ShapedBuffer* result, SessionModule* session_module) {
+Status LocalExecutable::RecordResult(const ShapedBuffer* result,
+ SessionModule* session_module) {
session_module->clear_result();
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
LiteralFromShapedBuffer(*result));
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index d8fd7a5623..31950377f4 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -59,7 +59,7 @@ class LocalExecutable {
// Validates that the given arguments and options satisfy various constraints
// of the computation.
- tensorflow::Status ValidateExecutionOptions(
+ Status ValidateExecutionOptions(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableRunOptions& run_options, const Backend& backend);
@@ -71,13 +71,13 @@ class LocalExecutable {
// Records the arguments used to invoke the computation in a SessionModule
// proto.
- tensorflow::Status RecordArguments(
+ Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
SessionModule* session_module);
// Records the result of the computation in a SessionModule proto.
- tensorflow::Status RecordResult(const ShapedBuffer* result,
- SessionModule* session_module);
+ Status RecordResult(const ShapedBuffer* result,
+ SessionModule* session_module);
// Returns a literal containing the contents of the given ShapedBuffer.
StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 1899983e44..2c6b6c60bb 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -437,7 +437,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions);
}
-XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) {
+XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = literal.shape();
@@ -1173,6 +1173,10 @@ XlaOp XlaBuilder::Exp(const XlaOp& operand) {
return UnaryOp(HloOpcode::kExp, operand);
}
+XlaOp XlaBuilder::Expm1(const XlaOp& operand) {
+ return UnaryOp(HloOpcode::kExpm1, operand);
+}
+
XlaOp XlaBuilder::Floor(const XlaOp& operand) {
return UnaryOp(HloOpcode::kFloor, operand);
}
@@ -1189,6 +1193,10 @@ XlaOp XlaBuilder::Log(const XlaOp& operand) {
return UnaryOp(HloOpcode::kLog, operand);
}
+XlaOp XlaBuilder::Log1p(const XlaOp& operand) {
+ return UnaryOp(HloOpcode::kLog1p, operand);
+}
+
XlaOp XlaBuilder::Sign(const XlaOp& operand) {
return UnaryOp(HloOpcode::kSign, operand);
}
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 4955f1515d..e5807033d3 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -139,7 +139,7 @@ class XlaBuilder {
// Enqueues a constant with the value of the given literal onto the
// computation.
- XlaOp ConstantLiteral(const Literal& literal);
+ XlaOp ConstantLiteral(const LiteralSlice& literal);
// Enqueues a constant onto the computation. Methods are templated on the
// native host type (NativeT) which corresponds to a specific XLA
@@ -571,6 +571,9 @@ class XlaBuilder {
// Enqueues an exp instruction onto the computation.
XlaOp Exp(const XlaOp& operand);
+ // Enqueues an expm1 instruction onto the computation.
+ XlaOp Expm1(const XlaOp& operand);
+
// Enqueues a floor instruction onto the computation.
XlaOp Floor(const XlaOp& operand);
@@ -584,6 +587,9 @@ class XlaBuilder {
  // Enqueues a log instruction (natural logarithm) onto the computation.
XlaOp Log(const XlaOp& operand);
+  // Enqueues a log1p instruction (log(x+1)) onto the computation.
+ XlaOp Log1p(const XlaOp& operand);
+
// Enqueues a sign instruction onto the computation.
XlaOp Sign(const XlaOp& operand);
diff --git a/tensorflow/compiler/xla/error_spec.h b/tensorflow/compiler/xla/error_spec.h
new file mode 100644
index 0000000000..a1463aa159
--- /dev/null
+++ b/tensorflow/compiler/xla/error_spec.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_
+#define TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_
+
+namespace xla {
+
+// Structure describing permissible absolute and relative error bounds.
+struct ErrorSpec {
+ explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false)
+ : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {}
+
+ float abs; // Absolute error bound.
+ float rel; // Relative error bound.
+
+ // If relaxed_nans is true then any result is valid if we are expecting NaNs.
+ // In effect, this allows the tested operation to produce incorrect results
+ // for inputs outside its mathematical domain.
+ bool relaxed_nans;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_
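For reference, constructing a spec that tolerates small absolute or relative error and relaxes NaN checking might look like this (a usage sketch, not part of the new header):

  // Allow |actual - expected| <= 1e-4 or relative error <= 1e-5, and accept any
  // actual value wherever a NaN is expected.
  xla::ErrorSpec spec(/*aabs=*/1e-4, /*arel=*/1e-5, /*relaxed_nans=*/true);
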
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index c6f8f6766e..a76fdcda25 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -140,8 +140,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
LayoutUtil::SetToDefaultLayout(program_shape->mutable_result());
}
-/* static */ tensorflow::Status LayoutUtil::ValidateLayoutInShape(
- const Shape& shape) {
+/* static */ Status LayoutUtil::ValidateLayoutInShape(const Shape& shape) {
if (ShapeUtil::IsTuple(shape)) {
// Tuple shape.
if (shape.has_layout()) {
@@ -150,12 +149,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
for (auto& element_shape : shape.tuple_shapes()) {
TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape));
}
- return tensorflow::Status::OK();
+ return Status::OK();
} else if (ShapeUtil::IsOpaque(shape)) {
if (shape.has_layout()) {
return InvalidArgument("opaque should not have a layout field");
}
- return tensorflow::Status::OK();
+ return Status::OK();
} else {
// Array shape.
if (!shape.has_layout()) {
@@ -166,14 +165,14 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
}
-/* static */ tensorflow::Status LayoutUtil::ValidateLayoutForShape(
- const Layout& layout, const Shape& shape) {
+/* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout,
+ const Shape& shape) {
if (ShapeUtil::IsTuple(shape)) {
return InvalidArgument("a single Layout is not valid for tuple shapes");
}
if (ShapeUtil::IsOpaque(shape)) {
- return tensorflow::Status::OK();
+ return Status::OK();
}
if (layout.format() == INVALID_FORMAT) {
@@ -225,7 +224,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
/* static */ void LayoutUtil::ClearLayout(Shape* shape) {
@@ -384,7 +383,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
namespace {
// Internal helper for recursively copying layouts.
-tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) {
+Status CopyLayoutInternal(const Shape& src, Shape* dst) {
if (ShapeUtil::IsTuple(src) != ShapeUtil::IsTuple(*dst)) {
return InvalidArgument(
"cannot copy layout from shape: shape structure differs");
@@ -411,14 +410,13 @@ tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) {
dst->clear_layout();
}
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
} // namespace
/* static */
-tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src,
- Shape* dst) {
+Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
return CopyLayoutInternal(src, dst);
}
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index 6cec750101..d3d6a2cc94 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -20,9 +20,9 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -61,12 +61,12 @@ class LayoutUtil {
static void SetToDefaultLayout(ProgramShape* program_shape);
// Validates that the layout within the given shape is correct.
- static tensorflow::Status ValidateLayoutInShape(const Shape& shape);
+ static Status ValidateLayoutInShape(const Shape& shape);
// Validates that the provided layout satisfies invariants for the given
// shape.
- static tensorflow::Status ValidateLayoutForShape(const Layout& layout,
- const Shape& shape);
+ static Status ValidateLayoutForShape(const Layout& layout,
+ const Shape& shape);
// Clears the layout in the given Shape. After this function is called,
// HasLayout will return false for the shape.
@@ -179,8 +179,7 @@ class LayoutUtil {
// tuples. 'src' and 'dst' need not be compatible but the two shapes must
  // have the same tuple structure (if any), and arrays within the shapes must
  // have the same rank (i.e. the same number of dimensions).
- static tensorflow::Status CopyLayoutBetweenShapes(const Shape& src,
- Shape* dst);
+ static Status CopyLayoutBetweenShapes(const Shape& src, Shape* dst);
// Returns true if the layouts of lhs and rhs are equal, false
// otherwise. Recursively compares layouts of tuples.
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index bc8405703b..f42fb92359 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -47,6 +47,12 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
// Set cudnn batchnorm off by default; it does not provide a performance win
// on average.
flags->set_xla_gpu_use_cudnn_batchnorm(false);
+
+ // Run all GPU work on one stream by default. Using multiple streams
+ // increases memory usage and we lack strong motivating benchmarks for tuning
+ // the heuristics needed to decide when to run on multiple streams. See
+ // b/77879207.
+ flags->set_xla_gpu_disable_multi_streaming(true);
}
// Allocates flag_values and flag_objects; this function must not be called more
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
new file mode 100644
index 0000000000..3696fdbe12
--- /dev/null
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -0,0 +1,739 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/literal_comparison.h"
+
+#include <unistd.h>
+#include <cmath>
+#include <vector>
+
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+
+using tensorflow::strings::Appendf;
+using tensorflow::strings::Printf;
+using tensorflow::strings::StrAppend;
+using tensorflow::strings::StrCat;
+
+namespace xla {
+namespace literal_comparison {
+namespace {
+
+// Helper function for comparing a floating point type, FloatT, bitwise equal
+// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
+// -- on miscompare, a nice error message is given in the AssertionFailure.
+template <typename FloatT, typename UnsignedT>
+Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
+ auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
+ auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
+ auto lhs_double = static_cast<double>(lhs);
+ auto rhs_double = static_cast<double>(rhs);
+ if (ulhs != urhs) {
+ return InvalidArgument(
+ "floating values are not bitwise-equal; and equality testing "
+ "was requested: %s=%g=%a vs %s=%g=%a",
+ StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double,
+ StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double);
+ }
+ return Status::OK();
+}
+
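As a standalone illustration of why the comparison goes through the bit pattern rather than operator== (not part of this file; it uses std::memcpy to stay portable): bitwise equality distinguishes +0.0f from -0.0f and lets two NaNs with the same payload compare equal.

  #include <cstdint>
  #include <cstring>

  bool BitwiseEqualF32(float lhs, float rhs) {
    uint32_t ul, ur;
    std::memcpy(&ul, &lhs, sizeof(ul));
    std::memcpy(&ur, &rhs, sizeof(ur));
    return ul == ur;
  }
  // BitwiseEqualF32(0.0f, -0.0f) is false although 0.0f == -0.0f is true, and
  // BitwiseEqualF32(nan, nan) is true for an identical payload although
  // nan == nan is always false.
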
+// Templated comparator that specializes for float equality comparison with the
+// bitwise helper above (this is the un-specialized fallback, which simply
+// compares values with operator==).
+template <typename NativeT>
+Status CompareEqual(NativeT lhs, NativeT rhs) {
+ if (lhs == rhs) {
+ return Status::OK();
+ }
+ return InvalidArgument("Expected equality of these values:\n %s\n %s",
+ StrCat(lhs).c_str(), StrCat(rhs).c_str());
+}
+
+// Specializations for floating types that do bitwise comparisons when equality
+// comparison is requested.
+template <>
+Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
+ return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
+}
+template <>
+Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs) {
+ return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
+}
+template <>
+Status CompareEqual<float>(float lhs, float rhs) {
+ return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
+}
+template <>
+Status CompareEqual<double>(double lhs, double rhs) {
+ return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
+}
+template <>
+Status CompareEqual<complex64>(complex64 lhs, complex64 rhs) {
+ auto res = CompareEqual<float>(lhs.real(), rhs.real());
+ if (!res.ok()) {
+ return res;
+ }
+ return CompareEqual<float>(lhs.imag(), rhs.imag());
+}
+
+// A recursive function which iterates through every index of expected and
+// actual literal and compares their values elementwise. Returns an OK status
+// if all corresponding elements are equal.
+template <typename NativeT>
+Status Equal(LiteralSlice expected, LiteralSlice actual,
+ tensorflow::gtl::MutableArraySlice<int64> multi_index,
+ int64 dimension) {
+ if (dimension == expected.shape().dimensions_size()) {
+ NativeT expected_value = expected.Get<NativeT>(multi_index);
+ NativeT actual_value = actual.Get<NativeT>(multi_index);
+ return CompareEqual<NativeT>(expected_value, actual_value);
+ }
+
+ Status result;
+ for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
+ multi_index[dimension] = i;
+ result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1));
+ }
+ return result;
+}
+
+// Gets the total element count. For tuples, this is not the count of tuple
+// elements, but the sum of elements of each tuple element.
+int64 RecursiveElementCount(const Shape& shape) {
+ if (ShapeUtil::IsTuple(shape)) {
+ const int64 tuple_elements = ShapeUtil::TupleElementCount(shape);
+ int64 total = 0;
+ for (int64 i = 0; i < tuple_elements; ++i) {
+ total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
+ }
+ return total;
+ } else {
+ return ShapeUtil::ElementsIn(shape);
+ }
+}
+
+// Returns whether the actual and expected values are mismatched with respect to
+// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec.
+template <typename NativeT>
+bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) {
+ if (relaxed_nans) {
+ return !std::isnan(expected) && std::isnan(actual);
+ } else {
+ return std::isnan(expected) != std::isnan(actual);
+ }
+}
+
+template <>
+bool NanMismatch<complex64>(complex64 expected, complex64 actual,
+ bool relaxed_nans) {
+ return NanMismatch<float>(expected.real(), actual.real(), relaxed_nans) ||
+ NanMismatch<float>(expected.imag(), actual.imag(), relaxed_nans);
+}
+
+template <>
+bool NanMismatch<half>(half expected, half actual, bool relaxed_nans) {
+ return NanMismatch<float>(static_cast<float>(expected),
+ static_cast<float>(actual), relaxed_nans);
+}
+
+// Converts the given floating-point value to a string.
+template <typename NativeT>
+string FpValueToString(NativeT value) {
+ return Printf("%8.4g", static_cast<double>(value));
+}
+
+template <>
+string FpValueToString<complex64>(complex64 value) {
+ return Printf("%8.4g + %8.4fi", value.real(), value.imag());
+}
+
+// Returns the absolute value of the given floating point value. This function
+// is used instead of std::abs directly in order to allow type-dependent
+// implementations for NearComparator.
+template <typename NativeT>
+float FpAbsoluteValue(NativeT value) {
+ return std::abs(value);
+}
+
+template <>
+float FpAbsoluteValue(bfloat16 value) {
+ return FpAbsoluteValue<float>(static_cast<float>(value));
+}
+
+template <>
+float FpAbsoluteValue(half value) {
+ return FpAbsoluteValue<float>(static_cast<float>(value));
+}
+
+// Helper class for comparing floating-point literals within an error bound.
+template <typename NativeT>
+class NearComparator {
+ public:
+ // Compares the two array literals elementwise and returns a comparison
+ // result. The comparison is ok() if all actual and expected elements are
+ // within the given error bound. In case of error, the status contains a
+ // detailed message about the discrepancy.
+ static Status Compare(const LiteralSlice& expected,
+ const LiteralSlice& actual, ErrorSpec error,
+ bool detailed_message,
+ const MiscompareCallback& miscompare_callback) {
+ NearComparator<NativeT> comparator(expected, actual, error,
+ detailed_message, miscompare_callback);
+ return comparator.Run();
+ }
+
+ private:
+ // Data structure encapsulating metadata about a single element mismatch.
+ struct Mismatch {
+ NativeT actual;
+ NativeT expected;
+ float rel_error;
+ float abs_error;
+
+ // The linear index of the failure within the shape. This linear index is
+ // from the 'actual' literal.
+ int64 linear_index;
+
+ bool operator<(const Mismatch& other) const {
+ return rel_error < other.rel_error;
+ }
+
+ string ToString(const Shape& shape) const {
+ return Printf(
+ "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
+ FpValueToString(actual).c_str(), FpValueToString(expected).c_str(),
+ Literal::MultiIndexAsString(
+ IndexUtil::LinearIndexToMultidimensionalIndex(shape,
+ linear_index))
+ .c_str(),
+ rel_error, abs_error);
+ }
+ };
+
+ NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
+ ErrorSpec error, bool detailed_message,
+ const MiscompareCallback& miscompare_callback)
+ : expected_(expected),
+ actual_(actual),
+ error_(error),
+ detailed_message_(detailed_message),
+ miscompare_callback_(miscompare_callback),
+ abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}),
+ abs_error_buckets_(kErrorBucketBounds.size(), 0),
+ rel_error_buckets_(kErrorBucketBounds.size(), 0) {}
+
+ // Runs the comparison between expected and actual literals.
+ Status Run() {
+ VLOG(1) << "expected:";
+ XLA_VLOG_LINES(1, ToStringTruncated(expected_));
+ VLOG(1) << "actual:";
+ XLA_VLOG_LINES(1, ToStringTruncated(actual_));
+
+ // If the shapes mismatch, we simply fail the expectation instead of
+ // printing out data, as it's a type error rather than a value error.
+ TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
+ if (!ShapeUtil::IsArray(expected_.shape())) {
+ return InvalidArgument("Expected array shape; got %s.",
+ ShapeUtil::HumanString(expected_.shape()).c_str());
+ }
+
+ mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED));
+ mismatches_.PopulateWithValue(false);
+
+ CompareLiterals();
+
+ if (num_mismatches_ == 0) {
+ return Status::OK();
+ } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
+ miscompare_callback_(expected_, actual_, mismatches_);
+ }
+ return InvalidArgument("%s", ErrorMessage().c_str());
+ }
+
+ // Insert the given absolute value into the absolute value bucket vector. The
+ // bounds of the buckets are given by kAbsValueBucketBounds.
+ void UpdateAbsValueBucket(NativeT value, bool is_mismatch) {
+ // Adjust the bucket containing the absolute values of the 'actual'
+ // elements.
+ const float abs_value = FpAbsoluteValue(value);
+ for (int i = 0; i < abs_value_buckets_.size(); ++i) {
+ if (i == abs_value_buckets_.size() - 1 ||
+ (abs_value >= kAbsValueBucketBounds[i] &&
+ abs_value < kAbsValueBucketBounds[i + 1])) {
+ // The first value of the pair is the count of elements in the bucket,
+ // the second is the count of mismatches in the bucket.
+ abs_value_buckets_[i].first++;
+ if (is_mismatch) {
+ abs_value_buckets_[i].second++;
+ }
+ return;
+ }
+ }
+ }
+
+ // Insert the given error into the given error bucket vector.
+ void UpdateErrorBucket(
+ float error, tensorflow::gtl::MutableArraySlice<int64> error_buckets) {
+ CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
+ for (int i = 0; i < error_buckets.size(); ++i) {
+ if (error >= kErrorBucketBounds[i]) {
+ error_buckets[i]++;
+ }
+ }
+ }
+
+ // Compares the two given elements from the expected and actual literals at
+ // the given linear_index and keeps track of various mismatch statistics.
+ void CompareValues(NativeT expected, NativeT actual, int64 linear_index) {
+ const bool is_nan_mismatch =
+ NanMismatch(expected, actual, error_.relaxed_nans);
+ float abs_error;
+ float rel_error;
+ if (actual == expected) {
+ abs_error = 0;
+ rel_error = 0;
+ } else if (is_nan_mismatch) {
+ num_nan_mismatches_++;
+ // A nan mismatch is considered to have infinite error. rel_error is used
+ // for sorting a std::multiset of the top mismatches, and a nan value here
+ // would result in undefined behavior because NaNs do not satisfy the strict
+ // weak ordering requirement of std containers.
+ abs_error = std::numeric_limits<float>::infinity();
+ rel_error = std::numeric_limits<float>::infinity();
+ } else {
+ abs_error = FpAbsoluteValue(actual - expected);
+ rel_error = abs_error / FpAbsoluteValue(expected);
+ }
+ const bool is_abs_mismatch = abs_error > error_.abs;
+ const bool is_rel_mismatch = rel_error > error_.rel;
+ const bool is_mismatch =
+ is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch);
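+ // For example, with an abs bound of 1e-4 and a rel bound of 1e-3,
+ // expected=1000.0 and actual=1000.5 gives abs_error=0.5 and rel_error=5e-4:
+ // only the absolute bound is exceeded, so the element is counted in the
+ // abs-mismatch statistics below but is not a mismatch overall.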
+
+ // Update the error of the relative bucket only if the *absolute* error
+ // bound is exceeded and vice versa.
+ if (is_abs_mismatch) {
+ num_abs_mismatches_++;
+ UpdateErrorBucket(rel_error, &rel_error_buckets_);
+ }
+ if (is_rel_mismatch) {
+ num_rel_mismatches_++;
+ UpdateErrorBucket(abs_error, &abs_error_buckets_);
+ }
+
+ UpdateAbsValueBucket(actual, is_mismatch);
+
+ if (!is_mismatch) {
+ return;
+ }
+
+ num_mismatches_++;
+
+ // Keep track of the kTopRelativeErrorCount mismatches with the largest
+ // relative error.
+ if (top_rel_mismatches_.size() < kTopRelativeErrorCount ||
+ rel_error > top_rel_mismatches_.begin()->rel_error) {
+ Mismatch mismatch = {actual, expected, rel_error, abs_error,
+ linear_index};
+ top_rel_mismatches_.insert(mismatch);
+ if (top_rel_mismatches_.size() > kTopRelativeErrorCount) {
+ top_rel_mismatches_.erase(top_rel_mismatches_.begin());
+ }
+ }
+
+ mismatches_.data<bool>()[linear_index] = true;
+ }
+
+ // Compares the two literals elementwise.
+ void CompareLiterals() {
+ // Fast path optimization for the case where the layouts match.
+ if (LayoutUtil::Equal(actual_.shape().layout(),
+ expected_.shape().layout())) {
+ tensorflow::gtl::ArraySlice<const NativeT> expected_data =
+ expected_.data<NativeT>();
+ tensorflow::gtl::ArraySlice<const NativeT> actual_data =
+ actual_.data<NativeT>();
+ const int64 len = expected_data.size();
+ for (int64 i = 0; i < len; ++i) {
+ CompareValues(expected_data[i], actual_data[i], i);
+ }
+ return;
+ }
+ std::vector<int64> multi_index(ShapeUtil::Rank(actual_.shape()), 0);
+ CompareLiteralsSlow(0, &multi_index);
+ }
+
+ // Slow path for CompareLiterals when 'actual' and 'expected' literals have
+ // different layouts. In this case, multidimensional indices are constructed
+ // and indexed for each element.
+ void CompareLiteralsSlow(int64 dimension, std::vector<int64>* multi_index) {
+ if (dimension == multi_index->size()) {
+ CompareValues(expected_.Get<NativeT>(*multi_index),
+ actual_.Get<NativeT>(*multi_index),
+ IndexUtil::MultidimensionalIndexToLinearIndex(
+ actual_.shape(), *multi_index));
+ } else {
+ for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) {
+ (*multi_index)[dimension] = i;
+ CompareLiteralsSlow(dimension + 1, multi_index);
+ }
+ }
+ }
+
+ // Returns an error message string with a detailed breakdown of the
+ // mismatches. Called by Run() once the comparison has completed.
+ string ErrorMessage() {
+ string out;
+ int64 element_count = ShapeUtil::ElementsIn(actual_.shape());
+
+ auto percent_string = [](float a, float b) {
+ float pct = b == 0.0 ? 0.0 : 100.0 * a / b;
+ return Printf("%0.4f%%", pct);
+ };
+
+ Appendf(&out,
+ "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound "
+ "%g, rel bound %g\n",
+ num_mismatches_,
+ percent_string(num_mismatches_, element_count).c_str(),
+ ShapeUtil::HumanString(actual_.shape()).c_str(),
+ ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
+ if (num_nan_mismatches_ > 0) {
+ StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n");
+ }
+ Appendf(&out, "Top relative error mismatches:\n");
+ for (auto it = top_rel_mismatches_.rbegin();
+ it != top_rel_mismatches_.rend(); ++it) {
+ StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n");
+ }
+
+ if (!detailed_message_) {
+ return out;
+ }
+
+ StrAppend(&out, "Absolute magnitude breakdown of actual values:\n");
+ CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size());
+ for (int i = 0; i < abs_value_buckets_.size(); ++i) {
+ const int64 bucket_size = abs_value_buckets_[i].first;
+ const int64 bucket_mismatches = abs_value_buckets_[i].second;
+ string mismatch_str = bucket_mismatches > 0
+ ? Printf(", mismatches %lld", bucket_mismatches)
+ : "";
+ Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n",
+ kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
+ bucket_size, percent_string(bucket_size, element_count).c_str(),
+ mismatch_str.c_str());
+ }
+
+ auto print_accum_buckets = [&](const string& header, int64 total,
+ tensorflow::gtl::ArraySlice<int64> buckets) {
+ StrAppend(&out, header, ":\n");
+ Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0],
+ total - buckets[0],
+ percent_string(total - buckets[0], total).c_str());
+ CHECK_EQ(buckets.size(), kErrorBucketBounds.size());
+ for (int i = 0; i < kErrorBucketBounds.size(); ++i) {
+ Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i],
+ buckets[i], percent_string(buckets[i], total).c_str());
+ }
+ };
+ Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n",
+ error_.abs, num_abs_mismatches_,
+ percent_string(num_abs_mismatches_, element_count).c_str());
+ print_accum_buckets(
+ "Relative error breakdown of elements exceeding abs error bound",
+ num_abs_mismatches_, rel_error_buckets_);
+ Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n",
+ error_.rel, num_rel_mismatches_,
+ percent_string(num_rel_mismatches_, element_count).c_str());
+ print_accum_buckets(
+ "Absolute error breakdown of elements exceeding rel error bound",
+ num_rel_mismatches_, abs_error_buckets_);
+ return out;
+ }
+
+ // 'actual' and 'expected' literals being compared.
+ LiteralSlice expected_;
+ LiteralSlice actual_;
+
+ // The error bounds of the comparison.
+ ErrorSpec error_;
+
+ // Whether to include detailed breakdown of mismatches in the error message.
+ bool detailed_message_;
+
+ // Callback to invoke on miscompare.
+ MiscompareCallback miscompare_callback_;
+
+ // Number of elementwise mismatches encountered so far.
+ int64 num_mismatches_ = 0;
+
+ // Number of elements with a nan mismatch.
+ int64 num_nan_mismatches_ = 0;
+
+ // Number of elements which exceed the absolute/relative error bound.
+ int64 num_abs_mismatches_ = 0;
+ int64 num_rel_mismatches_ = 0;
+
+ // A Literal recording which elements did not match in the expected and
+ // actual literals. mismatches_ contains PREDs and has the same shape as
+ // the compared literals.
+ Literal mismatches_;
+
+ // The number of mismatches to report in the output, sorted by relative error
+ // magnitude.
+ static constexpr int64 kTopRelativeErrorCount = 5;
+
+ // The set of mismatches with the largest relative error. The size of this set
+ // is bounded by kTopRelativeErrorCount.
+ std::multiset<Mismatch> top_rel_mismatches_;
+
+ // Actual values are bucketed by absolute value. kAbsValueBucketBounds holds
+ // the bounds of these buckets. abs_value_buckets_ contains a pair for each
+ // bucket: the element count and the mismatch count.
+ static constexpr std::array<float, 7> kAbsValueBucketBounds = {
+ 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits<float>::infinity()};
+ std::vector<std::pair<int64, int64>> abs_value_buckets_;
+
+ // Buckets for relative and absolute errors. The relative error buckets only
+ // contain those elements which exceed the *absolute* error bound, and vice
+ // versa. This makes it easy to see the effect of adjusting the relative (or
+ // absolute) error bound on the success of the comparison. kErrorBucketBounds
+ // are the lower bounds of the buckets in both vectors. The error buckets are
+ // a cumulative distribution, so an error value may appear in more than one
+ // bucket. For example, an error value of 0.003 appears in the buckets with
+ // lower bounds 0.0001 and 0.001.
+ static constexpr std::array<float, 5> kErrorBucketBounds = {0.0001, 0.001,
+ 0.01, 0.1, 1};
+ std::vector<int64> abs_error_buckets_;
+ std::vector<int64> rel_error_buckets_;
+};
+
+template <typename NativeT>
+constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
+template <typename NativeT>
+constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
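+
+// Note: the out-of-class definitions above are needed (pre-C++17) because the
+// static constexpr arrays are odr-used, e.g. when they are indexed at run time
+// in UpdateAbsValueBucket, UpdateErrorBucket, and ErrorMessage.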
+
+// Helper function for comparing two literals for nearness. Handles tuple-shapes
+// via recursion. shape_index is the ShapeIndex of expected (or actual)
+// currently being compared.
+Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
+ const ErrorSpec& error, bool detailed_message,
+ const MiscompareCallback& miscompare_callback,
+ const ShapeIndex& shape_index) {
+ TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
+
+ if (ShapeUtil::IsTuple(expected.shape())) {
+ Status return_status;
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+ const auto expected_element = LiteralSlice(expected, {i});
+ const auto actual_element = LiteralSlice(actual, {i});
+ ShapeIndex element_index = shape_index;
+ element_index.push_back(i);
+ Status res =
+ NearHelper(expected_element, actual_element, error, detailed_message,
+ miscompare_callback, element_index);
+ if (!res.ok()) {
+ string err_message = Printf("\nArray at shape index %s%s",
+ element_index.ToString().c_str(),
+ res.error_message().c_str());
+ if (return_status.ok()) {
+ return_status = InvalidArgument("%s", err_message.c_str());
+ } else {
+ return_status = AppendStatus(return_status, err_message);
+ }
+ }
+ }
+ if (!return_status.ok() && shape_index.empty()) {
+ // Emit a top-level error message containing the top-level shape in case
+ // of mismatch.
+ int64 total_elements = RecursiveElementCount(actual.shape());
+ return_status = InvalidArgument(
+ "\nMismatches in shape %s (%lld elements):\n%s",
+ ShapeUtil::HumanString(actual.shape()).c_str(), total_elements,
+ return_status.error_message().c_str());
+ }
+ return return_status;
+ }
+
+ if (ShapeUtil::ElementIsFloating(expected.shape()) ||
+ ShapeUtil::ElementIsComplex(expected.shape())) {
+ switch (expected.shape().element_type()) {
+ case BF16:
+ return NearComparator<bfloat16>::Compare(
+ expected, actual, error, detailed_message, miscompare_callback);
+ break;
+ case F16:
+ return NearComparator<half>::Compare(
+ expected, actual, error, detailed_message, miscompare_callback);
+ break;
+ case F32:
+ return NearComparator<float>::Compare(
+ expected, actual, error, detailed_message, miscompare_callback);
+ break;
+ case F64:
+ return NearComparator<double>::Compare(
+ expected, actual, error, detailed_message, miscompare_callback);
+ break;
+ case C64:
+ return NearComparator<complex64>::Compare(
+ expected, actual, error, detailed_message, miscompare_callback);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported primitive type in near comparator: "
+ << PrimitiveType_Name(expected.shape().element_type())
+ << ". Must be floating-point type.";
+ }
+ }
+
+ // Non-floating point literal.
+ return literal_comparison::Equal(expected, actual);
+}
+
+} // namespace
+
+Status EqualShapes(const Shape& expected, const Shape& actual) {
+ if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
+ return InvalidArgument("tupleness-mismatch! want: %s got %s",
+ ShapeUtil::HumanString(expected).c_str(),
+ ShapeUtil::HumanString(actual).c_str());
+ }
+ if (ShapeUtil::IsTuple(expected)) {
+ if (ShapeUtil::TupleElementCount(expected) !=
+ ShapeUtil::TupleElementCount(actual)) {
+ return InvalidArgument(
+ "want tuple element count: %lld got tuple element count: %lld",
+ ShapeUtil::TupleElementCount(expected),
+ ShapeUtil::TupleElementCount(actual));
+ }
+ for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
+ Status result =
+ EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
+ if (!result.ok()) {
+ return AppendStatus(result, StrCat("mismatch in tuple index ", i));
+ }
+ }
+ } else {
+ if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
+ return InvalidArgument("want rank of %s got rank of %s",
+ ShapeUtil::HumanString(expected).c_str(),
+ ShapeUtil::HumanString(actual).c_str());
+ }
+ if (expected.element_type() != actual.element_type()) {
+ return InvalidArgument(
+ "mismatch in primitive type %s vs %s",
+ PrimitiveType_Name(expected.element_type()).c_str(),
+ PrimitiveType_Name(actual.element_type()).c_str());
+ }
+ if (expected.dimensions_size() != actual.dimensions_size()) {
+ return InvalidArgument("want dimensions_size %d got dimensions_size %d",
+ expected.dimensions_size(),
+ actual.dimensions_size());
+ }
+ for (int i = 0; i < expected.dimensions_size(); ++i) {
+ if (expected.dimensions(i) != actual.dimensions(i)) {
+ return InvalidArgument(
+ "mismatch in dimension #%d expected: %s actual: %s", i,
+ ShapeUtil::HumanString(expected).c_str(),
+ ShapeUtil::HumanString(actual).c_str());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
+ VLOG(1) << "expected:";
+ XLA_VLOG_LINES(1, expected.ToString());
+ VLOG(1) << "actual:";
+ XLA_VLOG_LINES(1, actual.ToString());
+
+ TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
+ std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+ Status result;
+ switch (expected.shape().element_type()) {
+ case PRED:
+ result = Equal<bool>(expected, actual, &multi_index, 0);
+ break;
+ case U8:
+ result = Equal<uint8>(expected, actual, &multi_index, 0);
+ break;
+ case S32:
+ result = Equal<int32>(expected, actual, &multi_index, 0);
+ break;
+ case S64:
+ result = Equal<int64>(expected, actual, &multi_index, 0);
+ break;
+ case U32:
+ result = Equal<uint32>(expected, actual, &multi_index, 0);
+ break;
+ case U64:
+ result = Equal<uint64>(expected, actual, &multi_index, 0);
+ break;
+ case BF16:
+ result = Equal<bfloat16>(expected, actual, &multi_index, 0);
+ break;
+ case F16:
+ result = Equal<half>(expected, actual, &multi_index, 0);
+ break;
+ case F32:
+ result = Equal<float>(expected, actual, &multi_index, 0);
+ break;
+ case F64:
+ result = Equal<double>(expected, actual, &multi_index, 0);
+ break;
+ case C64:
+ result = Equal<complex64>(expected, actual, &multi_index, 0);
+ break;
+ case TUPLE: {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+ result.Update(
+ Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i})));
+ }
+ break;
+ }
+ default:
+ LOG(FATAL)
+ << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
+ << PrimitiveType_Name(expected.shape().element_type());
+ }
+
+ if (result.ok()) {
+ return Status::OK();
+ }
+
+ return AppendStatus(result,
+ tensorflow::strings::Printf("expected: %s\nactual: %s",
+ expected.ToString().c_str(),
+ actual.ToString().c_str()));
+}
+
+Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
+ const ErrorSpec& error, bool detailed_message,
+ const MiscompareCallback& miscompare_callback) {
+ return NearHelper(expected, actual, error, detailed_message,
+ miscompare_callback,
+ /*shape_index=*/{});
+}
+
+string ToStringTruncated(const LiteralSlice& literal) {
+ return RecursiveElementCount(literal.shape()) < 1000
+ ? literal.ToString()
+ : "[TRUNCATED, Literal with more than 1000 values]";
+}
+
+} // namespace literal_comparison
+} // namespace xla
diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h
new file mode 100644
index 0000000000..00a13e3619
--- /dev/null
+++ b/tensorflow/compiler/xla/literal_comparison.h
@@ -0,0 +1,72 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Library for comparing literals without taking a dependency on testing
+// libraries.
+
+#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
+#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
+
+#include "tensorflow/compiler/xla/error_spec.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+namespace literal_comparison {
+
+// Returns ok if the given shapes have the same rank, dimension sizes, and
+// primitive types.
+Status EqualShapes(const Shape& expected, const Shape& actual);
+
+// Returns ok if the expected and actual literals are (bitwise) equal for all
+// elements in the literal. Also asserts that the rank, dimension sizes, and
+// primitive type are equal.
+Status Equal(const LiteralSlice& expected, const LiteralSlice& actual);
+
+using MiscompareCallback =
+ std::function<void(const LiteralSlice& expected, const LiteralSlice& actual,
+ const LiteralSlice& mismatches)>;
+
+// Inspects whether the expected and actual literals are within the given error
+// bound for all elements. Also inspects whether the rank, dimension sizes,
+// and dimension bounds are equivalent.
+//
+// Tuples are matched recursively.
+//
+// When comparing tensors of non-floating-point type, this inspects for exact
+// equality, ignoring the ErrorSpec.
+//
+// If the shape of the literals is neither a complex/floating-point tensor nor a
+// tuple which contains a complex/floating-point tensor, Near() is equivalent to
+// Equal(). We don't raise an error in this case, because we want to allow
+// callers to call Near() even if they have no preconceptions about the shapes
+// being compared.
+//
+// If detailed_message is true, then the error message in the assertion result
+// will contain a more detailed breakdown of mismatches.
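+//
+// A minimal illustrative call, assuming the ErrorSpec(abs, rel) constructor
+// from error_spec.h and no miscompare callback:
+//
+//   Status s = literal_comparison::Near(expected, actual,
+//                                       ErrorSpec(1e-4, 1e-3),
+//                                       /*detailed_message=*/false,
+//                                       /*miscompare_callback=*/nullptr);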
+Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
+ const ErrorSpec& error, bool detailed_message,
+ const MiscompareCallback& miscompare_callback);
+
+// Calling ToString on a literal with over 100 million elements takes around
+// 3 minutes. The utility of printing a literal with >1000 elements is
+// questionable, especially when writing the Literal proto to disk is orders
+// of magnitude faster.
+string ToStringTruncated(const LiteralSlice& literal);
+
+} // namespace literal_comparison
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index b3b5e34ba2..82a2bcad76 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -62,8 +62,49 @@ void ConvertEndianShort(char* bytes, int64 size) {
}
}
+// Returns a copy of the given literal with all arrays of type FromNativeT
+// converted to type ToNativeT.
+template <typename FromNativeT, typename ToNativeT>
+std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
+ // First construct shape of the result.
+ Shape result_shape(literal.shape());
+ ShapeUtil::ForEachMutableSubshape(
+ &result_shape, [](Shape* subshape, const ShapeIndex&) {
+ if (subshape->element_type() ==
+ primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+ subshape->set_element_type(
+ primitive_util::NativeToPrimitiveType<ToNativeT>());
+ }
+ });
+ auto result = MakeUnique<Literal>(result_shape);
+
+ // Then copy over the data from 'literal' converting FromNativeT values to
+ // ToNativeT values as necessary.
+ ShapeUtil::ForEachSubshape(
+ literal.shape(),
+ [&](const Shape& subshape, const ShapeIndex& shape_index) {
+ if (ShapeUtil::IsArray(subshape)) {
+ if (subshape.element_type() ==
+ primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+ auto src = literal.data<FromNativeT>(shape_index);
+ auto dest = result->data<ToNativeT>(shape_index);
+ for (int64 i = 0; i < src.size(); ++i) {
+ dest[i] = static_cast<ToNativeT>(src[i]);
+ }
+ } else {
+ TF_CHECK_OK(result->CopyFrom(literal,
+ /*dest_shape_index=*/shape_index,
+ /*src_shape_index=*/shape_index));
+ }
+ }
+ });
+ return result;
+}
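+
+// For example, ConvertType<bfloat16, float>(some_literal) returns a literal
+// with the same tuple structure in which every BF16 array has been converted
+// element-by-element to F32, while arrays of any other element type are
+// copied over unchanged.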
+
} // namespace
+LiteralBase::~LiteralBase() {}
+
std::ostream& operator<<(std::ostream& out, const Literal& literal) {
out << literal.ToString();
return out;
@@ -95,99 +136,90 @@ Literal::StrideConfig::StrideConfig(
Literal::Literal(const Shape& shape)
: Literal(shape, /*allocate_arrays=*/true) {}
-Literal::Literal(const Shape& shape, bool allocate_arrays)
- : shape_(shape), pieces_(shape), owns_buffers_(true) {
- CHECK(LayoutUtil::HasLayout(shape));
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
-
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- const Shape& subshape = piece.subshape();
- if (ShapeUtil::IsArray(subshape)) {
- if (allocate_arrays) {
- if (LayoutUtil::IsSparseArray(subshape)) {
- // For sparse arrays, the buffer must be of the size of the maximum
- // number of sparse elements possible.
- const int64 max_sparse_elements =
- LayoutUtil::MaxSparseElements(subshape.layout());
- piece.set_buffer(
- new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType(
- subshape.element_type())]);
- piece.set_sparse_indices(new SparseIndexArray(
- max_sparse_elements, ShapeUtil::Rank(subshape)));
- } else {
- piece.set_buffer(new char[piece.size_bytes()]);
- }
+void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
+ if (ShapeUtil::IsTuple(shape)) {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& subshape = shape.tuple_shapes(i);
+
+ auto child_piece = Piece();
+ child_piece.set_subshape(&subshape);
+
+ SetPiece(subshape, &child_piece, allocate_arrays);
+
+ piece->emplace_back(std::move(child_piece));
+ }
+ } else {
+ CHECK(ShapeUtil::IsArray(shape));
+ if (allocate_arrays) {
+ if (LayoutUtil::IsSparseArray(shape)) {
+ // For sparse arrays, the buffer must be of the size of the maximum
+ // number of sparse elements possible.
+ const int64 max_sparse_elements =
+ LayoutUtil::MaxSparseElements(shape.layout());
+ piece->set_buffer(
+ new char[max_sparse_elements *
+ ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
+ piece->set_sparse_indices(
+ new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape)));
} else {
- piece.set_buffer(nullptr);
+ piece->set_buffer(new char[piece->size_bytes()]);
}
}
}
}
-Literal::~Literal() { DeallocateBuffers(); }
+Literal::Literal(const Shape& shape, bool allocate_arrays)
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(LayoutUtil::HasLayout(*shape_));
+ root_piece_ = new Piece();
+ root_piece_->set_subshape(shape_.get());
+ CHECK(&root_piece_->subshape() == shape_.get());
-void Literal::DeallocateBuffers() {
- if (owns_buffers_) {
- for (auto& pair : pieces_) {
- Piece& piece = pair.second;
- if (piece.buffer() != nullptr) {
- delete[] piece.buffer();
- delete piece.sparse_indices();
- }
- }
- }
+ SetPiece(*shape_, root_piece_, allocate_arrays);
}
-Literal::Literal(Literal&& other) {
- shape_ = std::move(other.shape_);
- pieces_ = std::move(other.pieces_);
- // We need to iterate through the pieces to set the subshape pointer
- // properly. It must refer to subshapes within shape_.
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+Literal::~Literal() {
+ if (root_piece_ != nullptr) {
+ DeallocateBuffers();
+ delete root_piece_;
}
- owns_buffers_ = other.owns_buffers_;
+}
- other.shape_ = ShapeUtil::MakeNil();
- other.pieces_ = ShapeTree<Piece>(other.shape_);
- other.piece({}).set_subshape(&other.shape_);
+void Literal::DeallocateBuffers() {
+ root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (piece->buffer() != nullptr) {
+ delete[] piece->buffer();
+ delete piece->sparse_indices();
+ }
+ });
}
+Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); }
+
Literal& Literal::operator=(Literal&& other) {
- DeallocateBuffers();
- shape_ = std::move(other.shape_);
- pieces_ = std::move(other.pieces_);
- // We need to iterate through the pieces to set the subshape pointer
- // properly. It must refer to subshapes within shape_.
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- }
- owns_buffers_ = other.owns_buffers_;
-
- other.shape_ = ShapeUtil::MakeNil();
- other.pieces_ = ShapeTree<Piece>(other.shape_);
- other.piece({}).set_subshape(&other.shape_);
+ CHECK(&other.root_piece_->subshape() == other.shape_.get());
+
+ using std::swap;
+ swap(shape_, other.shape_);
+ swap(root_piece_, other.root_piece_);
+ CHECK(&root_piece_->subshape() == shape_.get());
+
return *this;
}
-std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
+std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
auto literal = MakeUnique<Literal>(shape);
- for (auto& pair : literal->pieces_) {
- Piece& piece = pair.second;
- if (ShapeUtil::IsArray(piece.subshape())) {
- memset(piece.untyped_data(), 0, piece.size_bytes());
- }
- }
+ literal->root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (ShapeUtil::IsArray(piece->subshape())) {
+ memset(piece->untyped_data(), 0, piece->size_bytes());
+ }
+ });
return literal;
}
-const SparseIndexArray* Literal::sparse_indices(
+const SparseIndexArray* LiteralBase::sparse_indices(
const ShapeIndex& shape_index) const {
return piece(shape_index).sparse_indices();
}
@@ -202,9 +234,19 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
}
+/* static */ std::unique_ptr<Literal> Literal::ConvertBF16ToF32(
+ const LiteralSlice& bf16_literal) {
+ return ConvertType<bfloat16, float>(bf16_literal);
+}
+
+/* static */ std::unique_ptr<Literal> Literal::ConvertF32ToBF16(
+ const LiteralSlice& f32_literal) {
+ return ConvertType<float, bfloat16>(f32_literal);
+}
+
template <typename NativeT>
Status Literal::CopySliceFromInternal(
- const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
+ const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
@@ -217,8 +259,8 @@ Status Literal::CopySliceFromInternal(
if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
ShapeUtil::Rank(shape()) == 0) {
- // If any of the two shapes are scalars, we can just call the StridedCopy()
- // directly, and we know we will be copying only one value.
+ // If any of the two shapes are scalars, we can just call the
+ // StridedCopy() directly, and we know we will be copying only one value.
TF_RET_CHECK(copy_size.empty());
StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
src_literal.data<NativeT>(),
@@ -264,7 +306,7 @@ Status Literal::CopySliceFromInternal(
return Status::OK();
}
-Status Literal::CopyElementFrom(const Literal& src_literal,
+Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_index,
tensorflow::gtl::ArraySlice<int64> dest_index) {
DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
@@ -293,22 +335,21 @@ std::vector<Literal> Literal::DecomposeTuple() {
elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
/*allocate_arrays=*/false));
Literal& element = elements.back();
- for (auto& pair : element.pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& dest_piece = pair.second;
- ShapeIndex src_index = {i};
- for (int64 j : index) {
- src_index.push_back(j);
- }
- Piece& src_piece = piece(src_index);
-
- // Move the respective buffer and sparse indices over to the element
- // Literal.
- dest_piece.set_buffer(src_piece.buffer());
- src_piece.set_buffer(nullptr);
- dest_piece.set_sparse_indices(src_piece.sparse_indices());
- src_piece.set_sparse_indices(nullptr);
- }
+ element.root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* dest_piece) {
+ ShapeIndex src_index = {i};
+ for (int64 j : index) {
+ src_index.push_back(j);
+ }
+ Piece& src_piece = piece(src_index);
+
+ // Move the respective buffer and sparse indices over to the element
+ // Literal.
+ dest_piece->set_buffer(src_piece.buffer());
+ src_piece.set_buffer(nullptr);
+ dest_piece->set_sparse_indices(src_piece.sparse_indices());
+ src_piece.set_sparse_indices(nullptr);
+ });
}
// Set this literal to be nil-shaped.
*this = Literal();
@@ -331,9 +372,9 @@ std::vector<Literal> Literal::DecomposeTuple() {
}
namespace {
-
-// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
-// the array slices are indicated by dest_shape and src_shape respectively.
+// Copies the elements in 'src' to 'dest'. The shape and layout of the data
+// in the array slices are indicated by dest_shape and src_shape
+// respectively.
template <typename NativeT>
void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
tensorflow::gtl::ArraySlice<NativeT> src,
@@ -351,7 +392,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
} // namespace
-Status Literal::Piece::CopyFrom(const Literal::Piece& src) {
+Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
if (ShapeUtil::Equal(subshape(), src.subshape())) {
// If the layouts are equal it's faster just to memcpy.
memcpy(buffer(), src.buffer(), src.size_bytes());
@@ -381,14 +422,15 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) {
#undef COPY_ELEMENTS
default:
return Unimplemented(
- "Copying a Literal object with element type %s is not implemented.",
+ "Copying a Literal object with element type %s is not "
+ "implemented.",
PrimitiveType_Name(subshape().element_type()).c_str());
}
}
return Status::OK();
}
-Status Literal::CopyFrom(const Literal& src_literal,
+Status Literal::CopyFrom(const LiteralSlice& src_literal,
const ShapeIndex& dest_shape_index,
const ShapeIndex& src_shape_index) {
const Shape& dest_subshape =
@@ -402,36 +444,33 @@ Status Literal::CopyFrom(const Literal& src_literal,
ShapeUtil::HumanString(src_subshape).c_str());
}
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
-
- // Determine if this index is in the part of this literal that we want to
- // copy over from src_literal.
- bool in_subtree_to_copy = true;
- for (int i = 0; i < dest_shape_index.size(); ++i) {
- if (index[i] != dest_shape_index[i]) {
- in_subtree_to_copy = false;
- break;
- }
- }
- if (!in_subtree_to_copy) {
- continue;
- }
-
- // Construct the index of the corresponding piece in the source literal.
- ShapeIndex src_piece_index = src_shape_index;
- for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
- src_piece_index.push_back(index[i]);
- }
+ return root_piece_->ForEachMutableSubpieceWithStatus(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (!ShapeUtil::IsArray(piece->subshape())) {
+ return Status::OK();
+ }
- TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index)));
- }
- return Status::OK();
-}
+ // Determine if this index is in the part of this literal that we want
+ // to copy over from src_literal.
+ bool in_subtree_to_copy = true;
+ for (int i = 0; i < dest_shape_index.size(); ++i) {
+ if (index[i] != dest_shape_index[i]) {
+ in_subtree_to_copy = false;
+ break;
+ }
+ }
+ if (!in_subtree_to_copy) {
+ return Status::OK();
+ }
+ // Construct the index of the corresponding piece in the source literal.
+ ShapeIndex src_piece_index = src_shape_index;
+ for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
+ src_piece_index.push_back(index[i]);
+ }
+ TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
+ return Status::OK();
+ });
+}
Status Literal::MoveFrom(Literal&& src_literal,
const ShapeIndex& dest_shape_index) {
@@ -444,37 +483,32 @@ Status Literal::MoveFrom(Literal&& src_literal,
ShapeUtil::HumanString(src_literal.shape()).c_str());
}
- if (!(owns_buffers_ && src_literal.owns_buffers_)) {
- return InvalidArgument(
- "Source and destination literals must both own their buffers (ie, not "
- "be views)");
- }
+ src_literal.root_piece_->ForEachSubpiece(
+ [&](const ShapeIndex& src_index, const Piece& src_piece) {
+ if (!ShapeUtil::IsArray(src_piece.subshape())) {
+ return;
+ }
- for (auto& pair : src_literal.pieces_) {
- const ShapeIndex& src_index = pair.first;
- Piece& src_piece = pair.second;
- if (!ShapeUtil::IsArray(src_piece.subshape())) {
- continue;
- }
+ ShapeIndex dest_index = dest_shape_index;
+ for (int64 i : src_index) {
+ dest_index.push_back(i);
+ }
+ Piece& dest_piece = piece(dest_index);
+ delete[] dest_piece.buffer();
+ dest_piece.set_buffer(src_piece.buffer());
+ delete dest_piece.sparse_indices();
+ dest_piece.set_sparse_indices(src_piece.sparse_indices());
+ });
- ShapeIndex dest_index = dest_shape_index;
- for (int64 i : src_index) {
- dest_index.push_back(i);
- }
- Piece& dest_piece = piece(dest_index);
- delete[] dest_piece.buffer();
- dest_piece.set_buffer(src_piece.buffer());
- delete dest_piece.sparse_indices();
- dest_piece.set_sparse_indices(src_piece.sparse_indices());
- }
+ src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil());
+ delete src_literal.root_piece_;
+ src_literal.root_piece_ = new LiteralBase::Piece();
+ src_literal.root_piece_->set_subshape(src_literal.shape_.get());
- src_literal.shape_ = ShapeUtil::MakeNil();
- src_literal.pieces_ = ShapeTree<Piece>(src_literal.shape_);
- src_literal.piece({}).set_subshape(&src_literal.shape_);
return Status::OK();
}
-Status Literal::CopySliceFrom(const Literal& src_literal,
+Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
@@ -743,7 +777,7 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
return CreateR2FromArray2D(*value);
}
-std::unique_ptr<Literal> Literal::Relayout(
+std::unique_ptr<Literal> LiteralBase::Relayout(
const Layout& new_layout, const ShapeIndex& shape_index) const {
// Create new shape with 'new_layout' set at the given shape index.
Shape new_shape = shape();
@@ -755,7 +789,7 @@ std::unique_ptr<Literal> Literal::Relayout(
return result;
}
-std::unique_ptr<Literal> Literal::Relayout(
+std::unique_ptr<Literal> LiteralBase::Relayout(
const Shape& shape_with_layout) const {
CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
<< "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
@@ -774,7 +808,7 @@ std::unique_ptr<Literal> Literal::Relayout(
return result;
}
-StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
tensorflow::gtl::ArraySlice<int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
@@ -788,7 +822,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
}
// Because the layout is monotonic, we can simply reuse the same sequence of
// values without changing their order.
- output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions);
+ *output->mutable_shape_do_not_use() =
+ ShapeUtil::MakeShape(shape().element_type(), dimensions);
int64 elements_before = ShapeUtil::ElementsIn(shape());
int64 elements_after = ShapeUtil::ElementsIn(output->shape());
@@ -802,7 +837,79 @@ StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
return std::move(output);
}
-std::unique_ptr<Literal> Literal::Transpose(
+/* static */ std::unique_ptr<Literal> Literal::ReshapeSlice(
+ tensorflow::gtl::ArraySlice<int64> new_dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major,
+ const LiteralSlice& literal) {
+ int64 new_num_elements = 1;
+ for (int64 i = 0; i < new_dimensions.size(); ++i) {
+ new_num_elements *= new_dimensions[i];
+ }
+ CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
+ CHECK_EQ(new_dimensions.size(), minor_to_major.size());
+
+ auto new_literal = MakeUnique<Literal>(
+ ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
+
+ // Create a new shape with the given minor-to-major layout. This shape is used
+ // solely for converting linear addresses to multi-dimensional addresses when
+ // writing elements to the new literal.
+ Shape shape_with_layout = new_literal->shape();
+ *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
+
+ // Copy data into new literal, element-by-element.
+ for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
+ std::vector<int64> from_multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
+ std::vector<int64> to_multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
+ switch (literal.shape().element_type()) {
+ case PRED:
+ new_literal->Set<bool>(to_multi_index,
+ literal.Get<bool>(from_multi_index));
+ break;
+ case U8:
+ new_literal->Set<uint8>(to_multi_index,
+ literal.Get<uint8>(from_multi_index));
+ break;
+ case U32:
+ new_literal->Set<uint32>(to_multi_index,
+ literal.Get<uint32>(from_multi_index));
+ break;
+ case S32:
+ new_literal->Set<int32>(to_multi_index,
+ literal.Get<int32>(from_multi_index));
+ break;
+ case U64:
+ new_literal->Set<uint64>(to_multi_index,
+ literal.Get<uint64>(from_multi_index));
+ break;
+ case S64:
+ new_literal->Set<int64>(to_multi_index,
+ literal.Get<int64>(from_multi_index));
+ break;
+ case F32:
+ new_literal->Set<float>(to_multi_index,
+ literal.Get<float>(from_multi_index));
+ break;
+ case F64:
+ new_literal->Set<double>(to_multi_index,
+ literal.Get<double>(from_multi_index));
+ break;
+ case C64:
+ new_literal->Set<complex64>(to_multi_index,
+ literal.Get<complex64>(from_multi_index));
+ break;
+ default:
+ LOG(FATAL) << "Unhandled primitive element type: "
+ << PrimitiveType_Name(literal.shape().element_type());
+ }
+ }
+
+ return new_literal;
+}
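+
+// Note: the element at linear position i of 'literal' (with respect to its
+// own layout) is written to the multi-index that position i maps to under
+// 'new_dimensions' with the requested 'minor_to_major' order, so the element
+// sequence is preserved while both dimensions and layout are reinterpreted.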
+
+std::unique_ptr<Literal> LiteralBase::Transpose(
tensorflow::gtl::ArraySlice<int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
@@ -819,8 +926,8 @@ std::unique_ptr<Literal> Literal::Transpose(
// representation intact.
// For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
// The shape with affine layout resulting from that operation will be
- // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
- // most minor.
+ // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized),
+ // the most minor.
//
// Essentially, given MinMaj(Di) the position of the Di dimension within the
// minor to major vector, and given T(Di) the index that the original Di
@@ -836,12 +943,11 @@ std::unique_ptr<Literal> Literal::Transpose(
std::unique_ptr<Literal> new_literal = CreateFromShape(permuted_shape);
DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()),
ShapeUtil::ByteSizeOf(shape()));
- std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(),
- root_piece().size_bytes());
+ std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
return new_literal;
}
-std::unique_ptr<Literal> Literal::Slice(
+std::unique_ptr<Literal> LiteralBase::Slice(
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices) const {
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
@@ -909,20 +1015,20 @@ std::unique_ptr<Literal> Literal::Slice(
}
}
-Literal Literal::Clone() const {
+Literal LiteralBase::Clone() const {
Literal result(shape());
TF_CHECK_OK(result.CopyFrom(*this));
return result;
}
-std::unique_ptr<Literal> Literal::CloneToUnique() const {
+std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
auto result = MakeUnique<Literal>(shape());
TF_CHECK_OK(result->CopyFrom(*this));
return result;
}
-string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const {
+string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
CHECK(LayoutUtil::IsDenseArray(subshape));
switch (subshape.element_type()) {
@@ -962,8 +1068,8 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
}
}
-string Literal::GetSparseElementAsString(int64 sparse_element_number,
- const ShapeIndex& shape_index) const {
+string LiteralBase::GetSparseElementAsString(
+ int64 sparse_element_number, const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
CHECK(LayoutUtil::IsSparseArray(subshape));
switch (subshape.element_type()) {
@@ -1017,7 +1123,7 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number,
}
}
-StatusOr<int64> Literal::GetIntegralAsS64(
+StatusOr<int64> LiteralBase::GetIntegralAsS64(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
@@ -1070,7 +1176,7 @@ Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
return Status::OK();
}
-tensorflow::gtl::ArraySlice<int64> Literal::GetSparseIndex(
+tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
int64 sparse_element_number, const ShapeIndex& shape_index) const {
const Piece& p = piece(shape_index);
CHECK_GE(sparse_element_number, 0);
@@ -1082,10 +1188,10 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) {
piece(shape_index).SortSparseElements();
}
-Literal Literal::GetFirstScalarLiteral() const {
- CHECK(ShapeUtil::IsArray(shape_));
- CHECK_GT(ShapeUtil::ElementsIn(shape_), 0);
- switch (shape_.element_type()) {
+Literal LiteralBase::GetFirstScalarLiteral() const {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_GT(ShapeUtil::ElementsIn(shape()), 0);
+ switch (shape().element_type()) {
case PRED:
return std::move(*Literal::CreateR0<bool>(GetFirstElement<bool>()));
// 8 bit types.
@@ -1121,11 +1227,11 @@ Literal Literal::GetFirstScalarLiteral() const {
case U64:
return std::move(*Literal::CreateR0<uint64>(GetFirstElement<uint64>()));
default:
- LOG(FATAL) << "Unhandled primitive type " << shape_.element_type();
+ LOG(FATAL) << "Unhandled primitive type " << shape().element_type();
}
}
-void Literal::Piece::SortSparseElements() {
+void LiteralBase::Piece::SortSparseElements() {
switch (subshape().element_type()) {
case PRED:
SortSparseElementsInternal<bool>();
@@ -1176,7 +1282,7 @@ void Literal::Piece::SortSparseElements() {
}
template <typename NativeT>
-void Literal::Piece::SortSparseElementsInternal() {
+void LiteralBase::Piece::SortSparseElementsInternal() {
CHECK(LayoutUtil::IsSparseArray(subshape()));
int64 num_elements = sparse_indices()->index_count();
auto values = data<NativeT>();
@@ -1186,10 +1292,11 @@ void Literal::Piece::SortSparseElementsInternal() {
}
namespace {
-
-void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
+void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
bool print_layout, std::vector<string>* pieces) {
const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
+ CHECK(LayoutUtil::HasLayout(literal.shape()));
+ CHECK(LayoutUtil::HasLayout(subshape));
auto shape_to_string = [print_layout](const Shape& shape) {
if (print_layout) {
@@ -1348,13 +1455,14 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
} // namespace
-int64 Literal::sparse_element_count() const {
+int64 LiteralBase::sparse_element_count() const {
CHECK(LayoutUtil::IsSparseArray(shape()));
return sparse_indices()->index_count();
}
-string Literal::ToString(bool print_layout) const {
+string LiteralBase::ToString(bool print_layout) const {
std::vector<string> pieces;
+ CHECK(LayoutUtil::HasLayout(this->shape()));
ToStringHelper(*this, {}, print_layout, &pieces);
return tensorflow::str_util::Join(pieces, "");
}
@@ -1362,7 +1470,7 @@ string Literal::ToString(bool print_layout) const {
/* static */ std::unique_ptr<Literal> Literal::MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements) {
std::vector<Shape> element_shapes;
- for (const Literal* element : elements) {
+ for (const auto* element : elements) {
element_shapes.push_back(element->shape());
}
auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
@@ -1372,6 +1480,19 @@ string Literal::ToString(bool print_layout) const {
return literal;
}
+/* static */ std::unique_ptr<Literal> Literal::MakeTupleFromSlices(
+ tensorflow::gtl::ArraySlice<LiteralSlice> elements) {
+ std::vector<Shape> element_shapes;
+ for (const auto& element : elements) {
+ element_shapes.push_back(element.shape());
+ }
+ auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ for (int i = 0; i < elements.size(); ++i) {
+ TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
+ }
+ return literal;
+}
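+
+// For example, MakeTupleFromSlices({LiteralSlice(a), LiteralSlice(b)}) builds
+// a fresh two-element tuple literal and deep-copies a and b into it via
+// CopyFrom, leaving the source literals untouched.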
+
/* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned(
std::vector<std::unique_ptr<Literal>> elements) {
std::vector<Shape> element_shapes;
@@ -1387,7 +1508,7 @@ string Literal::ToString(bool print_layout) const {
return literal;
}
-void Literal::EachCellAsString(
+void LiteralBase::EachCellAsString(
const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
const string& value)>& per_cell) const {
if (ShapeUtil::HasZeroElements(shape())) {
@@ -1403,7 +1524,7 @@ void Literal::EachCellAsString(
namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
- const Literal& src_literal, const ConverterType& converter) {
+ const LiteralBase& src_literal, const ConverterType& converter) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
src_literal.shape(),
@@ -1419,7 +1540,8 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
}
template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
+std::unique_ptr<Literal> ConvertBetweenNativeTypes(
+ const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
@@ -1428,7 +1550,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const Literal& src_literal) {
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return tensorflow::bit_cast<NativeDestT>(src);
};
@@ -1436,19 +1558,19 @@ BitcastBetweenNativeTypes(const Literal& src_literal) {
src_literal, converter);
}
-// This template specialization is here to make the compiler happy. bit_cast has
-// a static check that the types are the same size. This specialization should
-// never be used because the source and destination types are checked for
-// identical sizes higher up.
+// This template specialization is here to make the compiler happy. bit_cast
+// has a static check that the types are the same size. This specialization
+// should never be used because the source and destination types are checked
+// for identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const Literal& src_literal) {
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
+std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
auto result_literal = MakeUnique<Literal>(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
@@ -1466,7 +1588,7 @@ std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal,
+std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
if (bitcast) {
@@ -1486,7 +1608,7 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal,
template <PrimitiveType primitive_src_type>
StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
- const Literal& src_literal, PrimitiveType primitive_dest_type,
+ const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
bool bitcast) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
@@ -1521,7 +1643,8 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
}
StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
- const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) {
+ const LiteralBase& literal, PrimitiveType primitive_dest_type,
+ bool bitcast) {
TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
if (literal.shape().element_type() == primitive_dest_type) {
return literal.CloneToUnique();
@@ -1555,17 +1678,18 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
} // namespace
-StatusOr<std::unique_ptr<Literal>> Literal::Convert(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
PrimitiveType primitive_dest_type) const {
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
-StatusOr<std::unique_ptr<Literal>> Literal::BitcastConvert(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
PrimitiveType primitive_dest_type) const {
if (primitive_util::BitWidth(shape().element_type()) !=
primitive_util::BitWidth(primitive_dest_type)) {
return InvalidArgument(
- "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
+ "Cannot bitcast convert from %s to %s, bit widths are different: %d "
+ "!= "
"%d",
PrimitiveType_Name(shape().element_type()).c_str(),
PrimitiveType_Name(primitive_dest_type).c_str(),
@@ -1575,7 +1699,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::BitcastConvert(
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
-StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
const Shape& dest_shape, bool round_f32_to_bf16) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
if (round_f32_to_bf16 && shape().element_type() == F32 &&
@@ -1590,7 +1714,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
}
std::vector<Literal> elements;
for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
- auto element = LiteralView::Create(*this, {i});
+ auto element = LiteralSlice(*this, {i});
TF_ASSIGN_OR_RETURN(
auto new_element,
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
@@ -1602,8 +1726,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
}
template <typename NativeT>
-bool Literal::Piece::EqualElementsInternal(
- const Literal::Piece& other, std::vector<int64>* multi_index) const {
+bool LiteralBase::Piece::EqualElementsInternal(
+ const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
if (multi_index->size() == ShapeUtil::Rank(subshape())) {
return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
}
@@ -1617,7 +1741,7 @@ bool Literal::Piece::EqualElementsInternal(
return true;
}
-bool Literal::Piece::EqualElements(const Literal::Piece& other) const {
+bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
std::vector<int64> multi_index;
@@ -1645,32 +1769,31 @@ bool Literal::Piece::EqualElements(const Literal::Piece& other) const {
case C64:
return EqualElementsInternal<complex64>(other, &multi_index);
default:
- LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type "
+ LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
<< PrimitiveType_Name(subshape().element_type());
}
}
-bool Literal::operator==(const Literal& other) const {
+bool LiteralBase::operator==(const LiteralBase& other) const {
if (!ShapeUtil::Compatible(shape(), other.shape())) {
return false;
}
- for (const auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- const Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
- const Piece& other_piece = other.piece(index);
- if (!piece.EqualElements(other_piece)) {
- return false;
- }
- }
- return true;
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
+
+ const Piece& other_piece = other.piece(index);
+ if (!piece.EqualElements(other_piece)) {
+ return false;
+ }
+ return true;
+ });
}
namespace {
-
template <typename NativeT>
static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
NativeT value) {
@@ -1684,11 +1807,11 @@ static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
} // namespace
-bool Literal::IsAll(int8 value) const {
- for (const auto& pair : pieces_) {
- const Piece& piece = pair.second;
+bool LiteralBase::IsAll(int8 value) const {
+ return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
+ const Piece& piece) {
if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
+ return true;
}
auto piece_is_all = [&]() {
@@ -1741,41 +1864,41 @@ bool Literal::IsAll(int8 value) const {
if (!piece_is_all()) {
return false;
}
- }
- return true;
-}
+ return true;
+ });
+}
-bool Literal::IsAllFloat(float value) const {
- for (const auto& pair : pieces_) {
- const Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
+bool LiteralBase::IsAllFloat(float value) const {
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
- auto piece_is_all = [&]() {
- switch (shape().element_type()) {
- case F32:
- return AllElementsEqualValue<float>(piece.data<float>(), value);
- case F64:
- return AllElementsEqualValue<double>(piece.data<double>(), value);
- case F16:
- return AllElementsEqualValue<half>(piece.data<half>(),
- static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
- static_cast<bfloat16>(value));
- default:
+ auto piece_is_all = [&]() {
+ switch (shape().element_type()) {
+ case F32:
+ return AllElementsEqualValue<float>(piece.data<float>(), value);
+ case F64:
+ return AllElementsEqualValue<double>(piece.data<double>(), value);
+ case F16:
+ return AllElementsEqualValue<half>(piece.data<half>(),
+ static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(
+ piece.data<bfloat16>(), static_cast<bfloat16>(value));
+ default:
+ return false;
+ }
+ };
+ if (!piece_is_all()) {
return false;
- }
- };
- if (!piece_is_all()) {
- return false;
- }
- }
- return true;
+ }
+ return true;
+ });
}
-bool Literal::IsAllComplex(complex64 value) const {
+bool LiteralBase::IsAllComplex(complex64 value) const {
switch (shape().element_type()) {
case C64:
return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
@@ -1785,93 +1908,93 @@ bool Literal::IsAllComplex(complex64 value) const {
}
}
-bool Literal::IsAllFirst() const {
- for (const auto& pair : pieces_) {
- const Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
-
- // Empty shapes are not all the first element since there is no first
- // element.
- if (ShapeUtil::HasZeroElements(piece.subshape())) {
- return false;
- }
- auto piece_is_all = [&]() {
- switch (piece.subshape().element_type()) {
- case PRED: {
- auto data = piece.data<bool>();
- return AllElementsEqualValue<bool>(data, data[0]);
- }
- // 8 bit types
- case S8: {
- auto data = piece.data<int8>();
- return AllElementsEqualValue<int8>(data, data[0]);
- }
- case U8: {
- auto data = piece.data<uint8>();
- return AllElementsEqualValue<uint8>(data, data[0]);
- }
- // 16 bit types
- case BF16: {
- auto data = piece.data<bfloat16>();
- return AllElementsEqualValue<bfloat16>(data, data[0]);
- }
- case F16: {
- auto data = piece.data<half>();
- return AllElementsEqualValue<half>(data, data[0]);
+bool LiteralBase::IsAllFirst() const {
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
}
- case S16: {
- auto data = piece.data<int16>();
- return AllElementsEqualValue<int16>(data, data[0]);
- }
- case U16: {
- auto data = piece.data<uint16>();
- return AllElementsEqualValue<uint16>(data, data[0]);
- }
- // 32 bit types
- case F32: {
- auto data = piece.data<float>();
- return AllElementsEqualValue<float>(data, data[0]);
- }
- case U32: {
- auto data = piece.data<uint32>();
- return AllElementsEqualValue<uint32>(data, data[0]);
- }
- case S32: {
- auto data = piece.data<int32>();
- return AllElementsEqualValue<int32>(data, data[0]);
- }
- // 64 bit types
- case C64: {
- auto data = piece.data<complex64>();
- return AllElementsEqualValue<complex64>(data, data[0]);
- }
- case F64: {
- auto data = piece.data<double>();
- return AllElementsEqualValue<double>(data, data[0]);
- }
- case S64: {
- auto data = piece.data<int64>();
- return AllElementsEqualValue<int64>(data, data[0]);
- }
- case U64: {
- auto data = piece.data<uint64>();
- return AllElementsEqualValue<uint64>(data, data[0]);
- }
- default:
+
+ // Empty shapes are not all the first element since there is no first
+ // element.
+ if (ShapeUtil::HasZeroElements(piece.subshape())) {
return false;
- }
- };
+ }
+ auto piece_is_all = [&]() {
+ switch (piece.subshape().element_type()) {
+ case PRED: {
+ auto data = piece.data<bool>();
+ return AllElementsEqualValue<bool>(data, data[0]);
+ }
+ // 8 bit types
+ case S8: {
+ auto data = piece.data<int8>();
+ return AllElementsEqualValue<int8>(data, data[0]);
+ }
+ case U8: {
+ auto data = piece.data<uint8>();
+ return AllElementsEqualValue<uint8>(data, data[0]);
+ }
+ // 16 bit types
+ case BF16: {
+ auto data = piece.data<bfloat16>();
+ return AllElementsEqualValue<bfloat16>(data, data[0]);
+ }
+ case F16: {
+ auto data = piece.data<half>();
+ return AllElementsEqualValue<half>(data, data[0]);
+ }
+ case S16: {
+ auto data = piece.data<int16>();
+ return AllElementsEqualValue<int16>(data, data[0]);
+ }
+ case U16: {
+ auto data = piece.data<uint16>();
+ return AllElementsEqualValue<uint16>(data, data[0]);
+ }
+ // 32 bit types
+ case F32: {
+ auto data = piece.data<float>();
+ return AllElementsEqualValue<float>(data, data[0]);
+ }
+ case U32: {
+ auto data = piece.data<uint32>();
+ return AllElementsEqualValue<uint32>(data, data[0]);
+ }
+ case S32: {
+ auto data = piece.data<int32>();
+ return AllElementsEqualValue<int32>(data, data[0]);
+ }
+ // 64 bit types
+ case C64: {
+ auto data = piece.data<complex64>();
+ return AllElementsEqualValue<complex64>(data, data[0]);
+ }
+ case F64: {
+ auto data = piece.data<double>();
+ return AllElementsEqualValue<double>(data, data[0]);
+ }
+ case S64: {
+ auto data = piece.data<int64>();
+ return AllElementsEqualValue<int64>(data, data[0]);
+ }
+ case U64: {
+ auto data = piece.data<uint64>();
+ return AllElementsEqualValue<uint64>(data, data[0]);
+ }
+ default:
+ return false;
+ }
+ };
- if (!piece_is_all()) {
- return false;
- }
- }
- return true;
+ if (!piece_is_all()) {
+ return false;
+ }
+ return true;
+ });
}
-bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
+bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
CHECK(ShapeUtil::IsArray(shape()));
switch (shape().element_type()) {
case U8:
@@ -1904,7 +2027,6 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
}
namespace {
-
template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
const tensorflow::gtl::ArraySlice<NativeT> src) {
@@ -1913,7 +2035,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest,
} // namespace
-void Literal::Piece::WriteToProto(LiteralProto* proto) const {
+void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
*proto->mutable_shape() = subshape();
switch (subshape().element_type()) {
case PRED:
@@ -1969,18 +2091,17 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const {
}
}
-const void* Literal::Piece::untyped_data() const {
+const void* LiteralBase::Piece::untyped_data() const {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
return buffer();
}
-void* Literal::Piece::untyped_data() {
+void* LiteralBase::Piece::untyped_data() {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
return buffer();
}
namespace {
-
template <typename RepeatedFieldT, typename NativeT>
Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
const RepeatedFieldT& src) {
@@ -1995,7 +2116,7 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
} // namespace
-Status Literal::Piece::CopyFromProto(const LiteralProto& proto) {
+Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
// These conditions should have been checked in Literal::CreateFromProto.
TF_RET_CHECK(proto.has_shape());
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
@@ -2062,21 +2183,19 @@ Status Literal::Piece::CopyFromProto(const LiteralProto& proto) {
return Status::OK();
}
-LiteralProto Literal::ToProto() const {
+LiteralProto LiteralBase::ToProto() const {
LiteralProto proto;
- for (const auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- const Piece& piece = pair.second;
-
- LiteralProto* proto_piece = &proto;
- for (int64 i : index) {
- while (proto_piece->tuple_literals_size() <= i) {
- proto_piece->add_tuple_literals();
- }
- proto_piece = proto_piece->mutable_tuple_literals(i);
- }
- piece.WriteToProto(proto_piece);
- }
+ root_piece().ForEachSubpiece(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ LiteralProto* proto_piece = &proto;
+ for (int64 i : index) {
+ while (proto_piece->tuple_literals_size() <= i) {
+ proto_piece->add_tuple_literals();
+ }
+ proto_piece = proto_piece->mutable_tuple_literals(i);
+ }
+ piece.WriteToProto(proto_piece);
+ });
if (LayoutUtil::IsSparseArray(shape())) {
CopyToRepeatedField(proto.mutable_sparse_indices(),
@@ -2098,33 +2217,39 @@ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
auto literal = MakeUnique<Literal>(proto.shape());
- for (auto& pair : literal->pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- const LiteralProto* proto_element = &proto;
- for (int64 i : index) {
- TF_RET_CHECK(i < proto_element->tuple_literals_size());
- proto_element = &proto_element->tuple_literals(i);
- }
+ TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+ [&](const ShapeIndex& index, Piece* piece) {
+ const LiteralProto* proto_element = &proto;
+ for (int64 i : index) {
+ CHECK(i < proto_element->tuple_literals_size());
+ proto_element = &proto_element->tuple_literals(i);
+ }
- if (ShapeUtil::IsTuple(piece.subshape())) {
- if (proto_element->tuple_literals_size() !=
- ShapeUtil::TupleElementCount(piece.subshape())) {
- return InvalidArgument(
- "Expected %lld tuple elements in LiteralProto, has %d",
- ShapeUtil::TupleElementCount(piece.subshape()),
- proto_element->tuple_literals_size());
- }
- continue;
- }
+ if (ShapeUtil::IsTuple(piece->subshape())) {
+ if (proto_element->tuple_literals_size() !=
+ ShapeUtil::TupleElementCount(piece->subshape())) {
+ return InvalidArgument(
+ "Expected %lld tuple elements in LiteralProto, has %d",
+ ShapeUtil::TupleElementCount(piece->subshape()),
+ proto_element->tuple_literals_size());
+ }
+ return Status::OK();
+ }
- TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape()));
- TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element));
- }
+ CHECK(ShapeUtil::IsArray(piece->subshape()));
+ TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
+
+ return Status::OK();
+ }));
return std::move(literal);
}
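A hedged round-trip sketch of the proto path implemented above (values and names are illustrative only):

    auto original = xla::Literal::CreateR2<xla::int32>({{1, 2}, {3, 4}});
    xla::LiteralProto proto = original->ToProto();
    auto restored = xla::Literal::CreateFromProto(proto).ValueOrDie();
    CHECK(*restored == *original);  // compatible shapes, equal data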
-const void* Literal::untyped_data(const ShapeIndex& shape_index) const {
+/* static */ string Literal::MultiIndexAsString(
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
+}
+
+const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
return piece(shape_index).untyped_data();
}
@@ -2132,11 +2257,11 @@ void* Literal::untyped_data(const ShapeIndex& shape_index) {
return piece(shape_index).untyped_data();
}
-int64 Literal::size_bytes(const ShapeIndex& shape_index) const {
+int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
return piece(shape_index).size_bytes();
}
-string Literal::GetR1U8AsString() const {
+string LiteralBase::GetR1U8AsString() const {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(shape().element_type(), U8);
@@ -2144,12 +2269,14 @@ string Literal::GetR1U8AsString() const {
ShapeUtil::ElementsIn(shape()));
}
-/* static */ const LiteralView LiteralView::Create(
- const Literal& literal, const ShapeIndex& view_root) {
- return LiteralView(literal, view_root);
-}
+LiteralSlice::LiteralSlice(const LiteralBase& literal)
+ : LiteralBase(), root_piece_(&literal.root_piece()) {}
-size_t Literal::Hash() const {
+LiteralSlice::LiteralSlice(const LiteralBase& literal,
+ const ShapeIndex& view_root)
+ : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
+
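For code migrating off the removed LiteralView::Create, a sketch of the replacement, assuming `tuple_literal` is an existing tuple-shaped literal whose element {0} is a rank-1 F32 array:

    // Non-owning view of element {0}; no buffers are copied or re-owned.
    xla::LiteralSlice element0(tuple_literal, /*view_root=*/{0});
    float first = element0.Get<float>({0});  // reads through the parent's storage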
+size_t LiteralBase::Hash() const {
using tensorflow::Hash64;
using tensorflow::Hash64Combine;
@@ -2170,46 +2297,4 @@ size_t Literal::Hash() const {
return hash_value;
}
-LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) {
- shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root);
- pieces_ = ShapeTree<Piece>(shape_);
- owns_buffers_ = false;
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
-
- ShapeIndex src_index = view_root;
- for (int64 i : index) {
- src_index.push_back(i);
- }
- const Piece& src_piece = literal.piece(src_index);
- piece.set_buffer(src_piece.buffer());
- piece.set_sparse_indices(src_piece.sparse_indices());
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- }
-}
-
-LiteralView::~LiteralView() {}
-
-LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); }
-
-LiteralView& LiteralView::operator=(const LiteralView& other) {
- CopyFrom(other);
- return *this;
-}
-
-void LiteralView::CopyFrom(const LiteralView& other) {
- // We can't use the default copy-constructor/copy-assignment because
- // Piece::subshape_ points to subshapes within the Shape of the owning
- // Literal/LiteralView.
- shape_ = other.shape();
- pieces_ = other.pieces_;
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- }
- owns_buffers_ = false;
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index c6bd03bf21..8d51aa3881 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -52,14 +51,491 @@ limitations under the License.
namespace xla {
+// Forward declare Literal and LiteralSlice class to be used by the creation
+// methods in the base class.
+class Literal;
+class LiteralSlice;
+
+// Abstract base class for literals.
+class LiteralBase {
+ public:
+ virtual ~LiteralBase() = 0;
+
+ // Literals are equal if they have compatible shapes and the same data
+ // values. Layout is not compared.
+ bool operator==(const LiteralBase& other) const;
+ bool operator!=(const LiteralBase& other) const { return !(*this == other); }
+
+ // Returns the shape of the literal.
+ const Shape& shape() const { return root_piece().subshape(); }
+
+ // Serialize to proto.
+ LiteralProto ToProto() const;
+
+ // Returns an ArraySlice of the array for this literal for the given NativeT
+ // (e.g., float). CHECKs if the subshape of the literal at the given
+ // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
+ // to native type.
+ template <typename NativeT>
+ tensorflow::gtl::ArraySlice<NativeT> data(
+ const ShapeIndex& shape_index = {}) const;
+
+ // Returns a const pointer to the sparse index array. Returns nullptr if the
+ // literal is not a sparse array.
+ const SparseIndexArray* sparse_indices(
+ const ShapeIndex& shape_index = {}) const;
+
+ // Returns a const pointer to (or size of) the underlying buffer holding the
+ // array at the given shape index. CHECKs if the subshape of the literal at
+ // the given ShapeIndex is not array.
+ const void* untyped_data(const ShapeIndex& shape_index = {}) const;
+ int64 size_bytes(const ShapeIndex& shape_index = {}) const;
+
+ // Returns this literal's data as a string. This literal must be a rank-1 U8
+ // array.
+ string GetR1U8AsString() const;
+
+ // Returns a string representation of the literal value.
+ // Warning: this function can take minutes for multi-million element Literals.
+ string ToString(bool print_layout = false) const;
+
+ // Gets an element in the literal at the given index. The multi_index is
+ // CHECKed against the dimension sizes.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const;
+ // Overloads of Get for array literals. CHECKs if the literal is not
+ // array-shaped and dense.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+
+ // Returns the element value at index (0, ..., 0), however many zeroes are
+ // required for that index.
+ template <typename NativeT>
+ NativeT GetFirstElement() const;
+
+ // As Get(), but determines the correct type and converts the value
+ // into text.
+ string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index = {}) const;
+ // As GetSparseElement(), but determines the correct type and converts the
+ // value into text.
+ string GetSparseElementAsString(int64 sparse_element_number,
+ const ShapeIndex& shape_index = {}) const;
+ // As Get(), but determines the correct type and converts the value into
+ // int64. This literal must be an array.
+ StatusOr<int64> GetIntegralAsS64(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const;
+
+ // Returns the multi-index of the element in a sparse literal at the given
+  // sparse element number. The sparse element number is the position within
+ // the sparse array's list of (index, value) pairs, and is checked against the
+ // total number of (index, value) pairs in the sparse array.
+ tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
+ int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
+
+ // Returns the value of the element in a sparse literal at the given sparse
+  // element number. The sparse element number is the position within the
+ // sparse array's list of (index, value) pairs, and is checked against the
+ // total number of (index, value) pairs in the sparse array.
+ template <typename NativeT>
+ NativeT GetSparseElement(int64 sparse_element_number,
+ const ShapeIndex& shape_index = {}) const;
+
+ // Invokes the "per cell" callback for each element in the provided
+ // literal with the element's indices and a string representation of
+ // the element's value.
+ //
+ // This function is useful if you want a polymorphic representation
+ // of the tensor's elements (turning it to a string for something
+ // like representation in a protobuf).
+ //
+ // This literal must have a dense layout.
+ void EachCellAsString(
+ const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const string& value)>& per_cell) const;
+ template <typename NativeT>
+ void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ NativeT value)>
+ per_cell) const;
+
+ // Returns whether every element in this literal is equal to value.
+ //
+ // value is an int8 because we expect this to be called with small
+ // compile-time constants (0, -1, etc.) and so that whatever value you pass
+ // can be represented exactly by floating-point types as small as 16 bits.
+ //
+ // If value doesn't fit in this literal's type, returns false. Values of 1/0
+ // are considered equal to true/false; other values are not considered equal
+  // to true. If this literal is not array-shaped, false is also returned.
+ bool IsAll(int8 value) const;
+
+  // Like IsAll(int8), except we check whether the literal is
+ // equal to a particular floating-point number.
+ //
+ // If the literal is not a floating-point value, this always returns false.
+ //
+ // This casts value to the type of literal, then compares using ==. The usual
+ // admonishments about floating-point equality checks apply. We expect you to
+ // use this to check for values that can be expressed precisely as a float,
+  // e.g. -0.5. If this literal is not array-shaped, false is also returned.
+ bool IsAllFloat(float value) const;
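A hedged illustration of the IsAll family (the literal and values are illustrative and chosen to be exactly representable):

    auto zeros = xla::Literal::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}});
    bool b1 = zeros->IsAll(0);           // true: 0 is representable as 0.0f
    bool b2 = zeros->IsAllFloat(-0.5f);  // false: the elements are 0.0f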
+
+  // Like IsAll(int8), except we check whether the literal is
+ // equal to a particular complex number.
+ //
+ // If the literal is not a complex value, this always returns false.
+ //
+ // This casts value to the type of literal, then compares using ==. The usual
+ // admonishments about floating-point equality checks apply. We expect you to
+ // use this to check for complex values that can be expressed precisely as
+ // float pairs e.g. (-0.5, 1.0).
+ //
+ // This literal must have a dense layout.
+ bool IsAllComplex(complex64 value) const;
+
+  // Returns whether the literal consists entirely of copies of its first
+  // element.
+ bool IsAllFirst() const;
+
+ // Returns whether this literal is zero at the specified index. This literal
+ // must be an array with a dense layout.
+ bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
+
+ // Returns the count of the elements in the array at the given shape index in
+ // this literal.
+ int64 element_count(const ShapeIndex& index = {}) const {
+ return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
+ }
+
+  // Returns the count of the elements in the sparse array at the given shape
+  // index in this literal, which will be no larger than
+  // LayoutUtil::MaxSparseElements(
+  //     ShapeUtil::GetSubshape(shape(), index).layout()).
+ int64 sparse_element_count() const;
+
+ // Compute a hash for this literal. This literal must not be a sparse tensor
+ // or a tuple containing a sparse tensor.
+ size_t Hash() const;
+
+  // Converts this literal to the given shape. Returns an error if the
+ // conversion is not possible.
+ //
+ // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
+ // instead of truncation; otherwise, truncation is used.
+ //
+ // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
+ // the default behavior.
+ StatusOr<std::unique_ptr<Literal>> ConvertToShape(
+ const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+
+ // Converts this literal to another primitive type using a bitcast
+ // conversion. The to and from primitive types must have the same bit
+ // width. Returns an error if the conversion is not possible. This literal
+ // must be array-shaped.
+ StatusOr<std::unique_ptr<Literal>> BitcastConvert(
+ PrimitiveType primitive_dest_type) const;
+
+ // Converts this literal to another primitive type. Returns an error if the
+ // conversion is not possible. This literal must be array-shaped.
+ StatusOr<std::unique_ptr<Literal>> Convert(
+ PrimitiveType primitive_dest_type) const;
+
+ // Returns a literal scalar representing the first element.
+ Literal GetFirstScalarLiteral() const;
+
+ // Clones the underlying buffers into a new Literal, or new
+ // std::unique_ptr<Literal>.
+ Literal Clone() const;
+ std::unique_ptr<Literal> CloneToUnique() const;
+
+ // TODO(b/67651157): The methods below which perform computation on Literals
+ // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
+ // evaluator code which operates on Literals.
+ //
+  // Creates a new value that is equivalent to this literal, but conforms to
+  // new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major
+  // dimension layout can be re-laid-out as {1, 0}
+ // minor-to-major dimension layout and the value in the cell at any given
+ // logical index (i0, i1) will be the same.
+ //
+ // For tuple shaped literals, shape_index should be used to select the inner
+ // array that the new layout applies to.
+ //
+ // Note: this is useful when the client wants to ensure that a value placed in
+ // the XLA allocation tracker has a particular layout; for efficiency
+ // purposes or avoiding unimplemented operation/layout combinations.
+ std::unique_ptr<Literal> Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index = {}) const;
+
+ // An overload of Relayout which changes the layout of the entire shape rather
+ // than being limited to a single array within the shape.
+ std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
+
+ // Creates a new literal by reshaping this literal to have the given
+  // dimensions. The total number of elements must not change; the
+ // implementation currently only supports monotonic dim0-major layouts.
+ // This literal must be an array.
+ StatusOr<std::unique_ptr<Literal>> Reshape(
+ tensorflow::gtl::ArraySlice<int64> dimensions) const;
+
+ // Creates a new literal by reordering the dimensions of this literal.
+ // The given `permutation` must be a permutation of the dimension numbers
+ // in the original literal, and it specifies the order of the new dimensions
+ // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
+ // For example, a transpose call on a literal of shape [3 x 8 x 4] and
+ // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
+ // This literal must be an array.
+ std::unique_ptr<Literal> Transpose(
+ tensorflow::gtl::ArraySlice<int64> permutation) const;
+
+ // Creates a sub-array from this literal by extracting the indices
+ // [start_index, limit_index) of each dimension. The result literal has the
+ // same rank and layout as for the given literal. The number of indices in
+ // start_indices and limit_indices must be the rank of the literal, and the
+ // indices follow the order of the dimensions.
+ // This literal must be an array.
+ std::unique_ptr<Literal> Slice(
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices) const;
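A hedged sketch of Slice (here `big` is an illustrative 4x4 F32 literal; the call keeps rows [1, 3) and all four columns):

    std::unique_ptr<xla::Literal> sub = big->Slice(/*start_indices=*/{1, 0},
                                                   /*limit_indices=*/{3, 4});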
+
+ // Creates a literal with a prepended dimension with bound "times"; e.g. a
+ // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
+ // literal replicated four times.
+ // This literal must be an array.
+ template <typename NativeT>
+ std::unique_ptr<Literal> Replicate(int64 times) const;
+
+ // Creates a new Literal object with the shape specified as parameter.
+ // The content of the literal values is the default value of the primitive
+ // type of literal itself (0 for numeric types, and false for predicates).
+ static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+
+ protected:
+ // A data structure representing a subshape at a particular ShapeIndex within
+ // the literal. For array-shaped ShapeIndexes, this data structure holds the
+ // pointer to the memory allocated for the array data.
+ class Piece {
+ public:
+ // Returns the buffer holding the array data for this piece as an array
+ // slice. This piece must be array-shaped.
+ template <typename NativeT>
+ tensorflow::gtl::ArraySlice<NativeT> data() const;
+ template <typename NativeT>
+ tensorflow::gtl::MutableArraySlice<NativeT> data();
+
+ // Returns the buffer holding the array data for this piece as a void*. This
+ // piece must be array-shaped.
+ void* untyped_data();
+ const void* untyped_data() const;
+
+ // Gets or sets an element in the array at the given index. The multi_index
+ // is CHECKed against the dimension sizes of the array. This piece must be
+ // array-shaped.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
+
+ // Gets/sets the buffer holding the array data.
+ char* buffer() const { return buffer_; }
+ void set_buffer(char* buffer) { buffer_ = buffer; }
+
+ // The array of multi-indices that provide the locations of non-zero
+ // elements in a sparse array. Only used if
+ // LayoutUtil::IsSparseArray(shape()) is true.
+ SparseIndexArray* sparse_indices() const { return sparse_indices_; }
+ void set_sparse_indices(SparseIndexArray* sparse_indices) {
+ sparse_indices_ = sparse_indices;
+ }
+
+ // Gets or sets the subshape of this piece. This reference points to a
+ // subshape within the shape in the containing Literal (Literal::shape_).
+ const Shape& subshape() const { return *subshape_; }
+ void set_subshape(const Shape* subshape) { subshape_ = subshape; }
+
+ // Returns the size in bytes of the buffer holding the array data.
+ int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
+
+ // Returns the number of elements in this piece's array.
+ int64 element_count() const {
+ // If this is a sparse array, use the number of elements represented by
+ // the indices in the associated SparseIndexArray.
+ return LayoutUtil::IsSparseArray(subshape())
+ ? sparse_indices()->index_count()
+ : ShapeUtil::ElementsIn(subshape());
+ }
+
+ // Returns the child piece at 'index' of this piece.
+ Piece& child(int64 index) { return children_[index]; }
+
+ // Adds a child piece to this piece's children.
+ void emplace_back(Piece child_piece) {
+ children_.emplace_back(std::move(child_piece));
+ }
+
+    // Returns the number of child pieces of this piece.
+ int64 children_size() { return children_.size(); }
+
+    // Visitor functions that recursively traverse this piece and call the
+    // given function at each subpiece, including this piece itself. The
+    // function has the type:
+ // void (const ShapeIndex& index, const Piece& piece)
+ template <typename Fn>
+ void ForEachSubpiece(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelper(
+ [&func](const ShapeIndex& index, const Piece& piece) {
+ func(index, piece);
+ return Status::OK();
+ },
+ *this, &index)
+ .IgnoreError();
+ }
+ // Same as above, but the function has the type:
+ // Status (const ShapeIndex& index, const Piece& piece)
+ // The first non-OK return value is returned by the function.
+ template <typename Fn>
+ Status ForEachSubpieceWithStatus(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelper(func, *this, &index);
+ }
+ // Same as above, but the function has the type:
+    //  bool (const ShapeIndex& index, const Piece& piece)
+    // The traversal returns false as soon as the function returns false for
+    // any subpiece.
+ template <typename Fn>
+ bool ForEachSubpieceWithBool(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelperBool(func, *this, &index);
+ }
+ // Same as above, but the function has the type:
+    //  void (const ShapeIndex& index, Piece* piece)
+ template <typename Fn>
+ void ForEachMutableSubpiece(const Fn& func) {
+ ShapeIndex index;
+ return ForEachMutableHelper(
+ [&func](const ShapeIndex& index, Piece* piece) {
+ func(index, piece);
+ return Status::OK();
+ },
+ const_cast<xla::LiteralBase::Piece*>(this), &index)
+ .IgnoreError();
+ }
+ // Same as above, but the function has the type:
+    //  Status (const ShapeIndex& index, Piece* piece)
+ // The first non-OK return value is returned by the function.
+ template <typename Fn>
+ Status ForEachMutableSubpieceWithStatus(const Fn& func) {
+ ShapeIndex index;
+ return ForEachMutableHelper(
+ func, const_cast<xla::LiteralBase::Piece*>(this), &index);
+ }
+
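A hedged example of the visitor helpers above, counting array-shaped subpieces; Piece and root_piece() are protected, so a real version of this would live inside LiteralBase or a friend class:

    xla::int64 num_arrays = 0;
    root_piece().ForEachSubpiece(
        [&](const xla::ShapeIndex& /*index*/, const Piece& piece) {
          if (xla::ShapeUtil::IsArray(piece.subshape())) {
            ++num_arrays;
          }
        });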
+ // Returns true if this piece and 'other' contain the same data. This piece
+ // and 'other' must be array-shaped and compatible.
+ bool EqualElements(const Piece& other) const;
+
+ // Writes the shape and data (if array-shaped) into the given proto.
+ void WriteToProto(LiteralProto* proto) const;
+
+ // Copy the data from 'src' into this piece's buffer. Shapes of this piece
+ // and src must be compatible.
+ Status CopyFrom(const Piece& src);
+
+ // Copies the data from the given proto into this piece. The shape of this
+ // piece must be equal (not just compatible) to the shape of the proto.
+ Status CopyFromProto(const LiteralProto& proto);
+
+ // Sorts the elements in a sparse array.
+ void SortSparseElements();
+
+ private:
+ // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
+ // The first non-OK (or non-true) value is returned by the function.
+ // The callable 'func' has the same signature as described above in
+ // ForEachSubpiece*.
+ template <typename Fn>
+ Status ForEachHelper(const Fn& func, const Piece& piece,
+ ShapeIndex* index) const {
+ TF_RETURN_IF_ERROR(func(*index, piece));
+ for (int64 i = 0; i < piece.children_.size(); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
+ index->pop_back();
+ }
+ return Status::OK();
+ }
+ template <typename Fn>
+ bool ForEachHelperBool(const Fn& func, const Piece& piece,
+ ShapeIndex* index) const {
+ if (!func(*index, piece)) {
+ return false;
+ }
+ for (int64 i = 0; i < piece.children_.size(); ++i) {
+ index->push_back(i);
+ if (!ForEachHelperBool(func, piece.children_[i], index)) {
+ return false;
+ }
+ index->pop_back();
+ }
+ return true;
+ }
+ template <typename Fn>
+ Status ForEachMutableHelper(const Fn& func, Piece* piece,
+ ShapeIndex* index) {
+ TF_RETURN_IF_ERROR(func(*index, piece));
+ for (int64 i = 0; i < piece->children_.size(); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(
+ ForEachMutableHelper(func, &piece->children_[i], index));
+ index->pop_back();
+ }
+ return Status::OK();
+ }
+
+ // Recursive helper for EqualElements.
+ template <typename NativeT>
+ bool EqualElementsInternal(const Piece& other,
+ std::vector<int64>* multi_index) const;
+
+ // Helper for SortSparseElements that has the element type as a template
+ // parameter.
+ template <typename NativeT>
+ void SortSparseElementsInternal();
+
+ // For array-shaped pieces, this is the buffer holding the literal data.
+ char* buffer_ = nullptr;
+
+ // For sparse arrays, this is the array of indices.
+ SparseIndexArray* sparse_indices_ = nullptr;
+
+    // The shape of this piece. This points into the shape of the containing
+    // Literal (Literal::shape_).
+ // (Literal::shape_).
+ const Shape* subshape_ = nullptr;
+
+ // Children pieces for tuple shaped pieces.
+ std::vector<Piece> children_ = {};
+ }; // class Piece
+
+ const Piece& piece(const ShapeIndex& shape_index) const {
+ Piece* piece = &const_cast<Piece&>(root_piece());
+ for (const auto i : shape_index) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, piece->children_size());
+ piece = &piece->child(i);
+ }
+ return *piece;
+ }
+
+ // Returns the piece at the root of the shape.
+ virtual const Piece& root_piece() const = 0;
+
+ // LiteralSlice and Literal must access Pieces of other Literals.
+ friend class LiteralSlice;
+ friend class Literal;
+};
+
// Class representing literal values in XLA.
//
-// TODO(b/67651157): The methods in this class should be reduced to a minimal
-// set of methods which construct Literals and accessors methods. Other methods
-// which perform computation on Literals (Reshape, Slice, etc) should be moved
-// elsewhere, and perhaps combined with evaluator code which operates on
-// Literals.
-class Literal {
+// The underlying buffer and shape is always owned by this class.
+class Literal : public LiteralBase {
public:
Literal() : Literal(ShapeUtil::MakeNil()) {}
@@ -80,46 +556,156 @@ class Literal {
Literal(const Shape& shape, bool allocate_arrays);
Literal& operator=(Literal&& other);
- // Literals are equal if they have compatible shapes and the same data
- // values. Layout is not compared.
- bool operator==(const Literal& other) const;
- bool operator!=(const Literal& other) const { return !(*this == other); }
-
- // Serialize to and from a proto.
- static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
- const LiteralProto& proto);
- LiteralProto ToProto() const;
-
- // Return the shape of the literal.
- const Shape& shape() const { return shape_; }
-
// TODO(b/67651157): Remove this accessor. Literal users should not be able to
// mutate the shape as this can produce malformed Literals.
- Shape* mutable_shape_do_not_use() { return &shape_; }
+ Shape* mutable_shape_do_not_use() { return shape_.get(); }
- // Returns a (Mutable)ArraySlice view of the array for this literal for the
+ // Returns a MutableArraySlice view of the array for this literal for the
// given NativeT (e.g., float). CHECKs if the subshape of the literal at the
// given ShapeIndex is not array. See primitive_util.h for the mapping from
// XLA type to native type.
template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data(
- const ShapeIndex& shape_index = {}) const;
- template <typename NativeT>
tensorflow::gtl::MutableArraySlice<NativeT> data(
const ShapeIndex& shape_index = {});
+ // Unhide const method from parent class.
+ using LiteralBase::data;
// Returns a pointer to the sparse index array. Returns nullptr if the literal
// is not a sparse array.
- const SparseIndexArray* sparse_indices(
- const ShapeIndex& shape_index = {}) const;
SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
- // Returns a pointer to (or size of) the underlying buffer holding the array
- // at the given shape index. CHECKs if the subshape of the literal at the
- // given ShapeIndex is not array.
- const void* untyped_data(const ShapeIndex& shape_index = {}) const;
+ // Returns a pointer to the underlying buffer holding the array at the given
+ // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
+ // is not array.
void* untyped_data(const ShapeIndex& shape_index = {});
- int64 size_bytes(const ShapeIndex& shape_index = {}) const;
+ // Unhide const method from parent class.
+ using LiteralBase::untyped_data;
+
+ // Populates a literal with a sparse layout with the given indices and values.
+ // Each index in the indices array is CHECKed against the dimensions in the
+ // literal's shape. If sort is true, then the indices and values will be
+ // sorted. If sort is false, then the indices and values are assumed to
+ // already be in sorted order. See CreateSparse for an example of how data
+ // are populated.
+ template <typename NativeT>
+ void PopulateSparse(SparseIndexArray indices,
+ tensorflow::gtl::ArraySlice<NativeT> values,
+ bool sort = true);
+
+ // Copy values from 'src_literal' rooted at 'src_shape_index' into this
+ // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
+ // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
+ // rooted at 'src_shape_index', but need not be arrays.
+ Status CopyFrom(const LiteralSlice& src_literal,
+ const ShapeIndex& dest_shape_index = {},
+ const ShapeIndex& src_shape_index = {});
+
+  // Similar to CopyFrom, but with move semantics. The subshape of this literal
+ // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
+ // (layouts and shapes must match), but need not be arrays. The memory
+ // allocated in this literal for the subshape at dest_shape_index is
+ // deallocated, and the respective buffers are replaced with those in
+ // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
+ Status MoveFrom(Literal&& src_literal,
+ const ShapeIndex& dest_shape_index = {});
+
+ // Copies the values from src_literal, starting at src_base shape indexes,
+ // to this literal, starting at dest_base, where the copy size in each
+ // dimension is specified by copy_size.
+ // The src_literal and this literal must have the same primitive type,
+ // src_base+copy_size must fit the source literal dimensions, as well as
+ // dest_base+copy_size must fit the destination literal dimensions.
+  // Note: if either src_literal or this literal contains dimensions with zero
+  // elements, then copy_size must be 0 in these dimensions and the
+  // corresponding base indices must be 0.
+ // This literal and 'src_literal' must be arrays.
+ Status CopySliceFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size);
+
+ // Copies one element from src_literal[src_index] to (*this)[dest_index].
+ Status CopyElementFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_index,
+ tensorflow::gtl::ArraySlice<int64> dest_index);
+
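A hedged sketch of CopySliceFrom (here `dst` and `src` are illustrative rank-2 literals of the same element type; a 2x2 window from `src` is copied into `dst` at offset (1, 1)):

    TF_CHECK_OK(dst->CopySliceFrom(xla::LiteralSlice(*src),
                                   /*src_base=*/{0, 0},
                                   /*dest_base=*/{1, 1},
                                   /*copy_size=*/{2, 2}));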
+ // Sets an element in the literal at the given index. The multi_index is
+ // CHECKed against the dimension sizes.
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index, NativeT value);
+ // Overloads of Set for array literals. CHECKs if the literal is not
+ // array-shaped and dense.
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
+
+ // Appends the given element to the literal. If the elements are not appended
+ // in sorted order, then SortSparseElements should be called before calling
+ // other methods. This literal must have a sparse layout.
+ template <typename NativeT>
+ void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value, const ShapeIndex& shape_index = {});
+
+ // Sorts the elements in a sparse array.
+ void SortSparseElements(const ShapeIndex& shape_index = {});
+
+ // As Set(), but truncates `value` to the literal element type before storing.
+ // This literal must be an array.
+ Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
+ int64 value);
+
+ // Populate this literal with the given values. Examples:
+ //
+ // // Populate with floats.
+ // Array2D<float> float_values = ...
+ // literal.PopulateR2FromArray2D(values);
+ //
+ // // Populate with int32s.
+ // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
+ //
+ // The shape and element type of this literal must match given values. For
+ // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
+ // array of S32.
+ template <typename NativeT>
+ void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
+ void PopulateR1(const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ void PopulateFromArray(const Array<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR2FromArray2D(const Array2D<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR3FromArray3D(const Array3D<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR4FromArray4D(const Array4D<NativeT>& values);
+
+ // Populates literal values by calling the generator function for every cell
+ // in this literal object.
+ //
+ // generator must be a callable of the type
+ // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+ //
+ // This literal must have a dense layout.
+ template <typename NativeT, typename FnType>
+ Status Populate(const FnType& generator);
+
+ // A parallel version of Populate(). This can be used if the generator is
+ // thread-safe and the values for the shape's different elements are
+ // independent.
+ template <typename NativeT, typename FnType>
+ Status PopulateParallel(const FnType& generator);
+
+ // Fills this literal with the given value.
+ template <typename NativeT>
+ void PopulateWithValue(NativeT value);
+
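A hedged generator example for Populate (a concrete 2x3 S32 literal is created first so the layout is dense; the fill rule is illustrative):

    auto lit = xla::Literal::CreateR2<xla::int32>({{0, 0, 0}, {0, 0, 0}});
    TF_CHECK_OK(lit->Populate<xla::int32>(
        [](tensorflow::gtl::ArraySlice<xla::int64> idx) {
          return static_cast<xla::int32>(idx[0] * 3 + idx[1]);
        }));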
+ // Factory methods below.
+ //
+
+ // Serialize from a proto.
+ static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
+ const LiteralProto& proto);
// Creates a new literal of a given rank. To minimize ambiguity (for users
// and the compiler) these CreateR[0-2] methods should explicitly specify the
@@ -167,10 +753,6 @@ class Literal {
values,
const Layout& layout);
- // Returns this literal's data as a string. This literal must be a rank-1 U8
- // array.
- string GetR1U8AsString() const;
-
// Creates a literal with a sparse layout and the given indices and values.
// The shape is initialized from the given dimensions. The minor dimension of
// the indices array must equal the rank of the shape (i.e. size of the
@@ -210,171 +792,16 @@ class Literal {
tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
- // Populates a literal with a sparse layout with the given indices and values.
- // Each index in the indices array is CHECKed against the dimensions in the
- // literal's shape. If sort is true, then the indices and values will be
- // sorted. If sort is false, then the indices and values are assumed to
- // already be in sorted order. See CreateSparse for an example of how data
- // are populated.
- template <typename NativeT>
- void PopulateSparse(SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort = true);
-
- // Creates a new Literal object with the shape specified as parameter.
- // The content of the literal values is the default value of the primitive
- // type of literal itself (0 for numeric types, and false for predicates).
- static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
-
- // Creates a new Literal object with its values havings the primitive_type
- // type, and with dimensions defined by the dimensions parameter.
- // The content of the literal values is the default value of the primitive
- // type of literal itself (0 for numeric types, and false for predicates).
- static std::unique_ptr<Literal> CreateFromDimensions(
- PrimitiveType primitive_type,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
- // Copy values from 'src_literal' rooted at 'src_shape_index' into this
- // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
- // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
- // rooted at 'src_shape_index', but need not be arrays.
- Status CopyFrom(const Literal& src_literal,
- const ShapeIndex& dest_shape_index = {},
- const ShapeIndex& src_shape_index = {});
-
- // Similar to CopyFrom, but with move semantincs. The subshape of this literal
- // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
- // (layouts and shapes must match), but need not be arrays. The memory
- // allocated in this literal for the subshape at dest_shape_index is
- // deallocated, and the respective buffers are replaced with those in
- // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
- Status MoveFrom(Literal&& src_literal,
- const ShapeIndex& dest_shape_index = {});
-
- // Copies the values from src_literal, starting at src_base shape indexes,
- // to this literal, starting at dest_base, where the copy size in each
- // dimension is specified by copy_size.
- // The src_literal and this literal must have the same primitive type,
- // src_base+copy_size must fit the source literal dimensions, as well as
- // dest_base+copy_size must fit the destination literal dimensions.
- // Note: if either src_literal or this literal contains dimensions with zero
- // element, then copy_size must be 0 in these dimensions while the
- // corresponding base indices being 0.
- // This literal and 'src_literal' must be arrays.
- Status CopySliceFrom(const Literal& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
-
- // Copies one element from src_literal[src_index] to (*this)[dest_index].
- Status CopyElementFrom(const Literal& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_index,
- tensorflow::gtl::ArraySlice<int64> dest_index);
-
- // Returns a vector containing the tuple elements of this Literal as separate
- // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
- // elements are moved into the new Literals; no data is copied. Upon return
- // this Literal is set to a nil shape (empty tuple)
- std::vector<Literal> DecomposeTuple();
-
- // This operation is the inverse of DecomposeTuple. The given elements are
- // moved into the tuple elements of a new tuple-shaped Literal which is
- // returned. Upon return, each of the Literals in 'elements' is set to a nil
- // shape (empty tuple).
- static Literal MoveIntoTuple(
- tensorflow::gtl::MutableArraySlice<Literal> elements);
-
- // Creates a new value that has the equivalent value as this literal, but
- // conforms to new_layout; e.g. a literal matrix that was in {0, 1}
- // minor-to-major dimension layout can be re-laid-out as {1, 0}
- // minor-to-major dimension layout and the value in the cell at any given
- // logical index (i0, i1) will be the same.
- //
- // For tuple shaped literals, shape_index should be used to select the inner
- // array that the new layout applies to.
- //
- // Note: this is useful when the client wants to ensure that a value placed in
- // the XLA allocation tracker has a particular layout; for efficiency
- // purposes or avoiding unimplemented operation/layout combinations.
- std::unique_ptr<Literal> Relayout(const Layout& new_layout,
- const ShapeIndex& shape_index = {}) const;
-
- // An overload of Relayout which changes the layout of the entire shape rather
- // than being limited to a single array within the shape.
- std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
-
- // Creates a new literal by reshaping this literal to have the given
- // dimensions. The total number of elements must not change; The
- // implementation currently only supports monotonic dim0-major layouts.
- // This literal must be an array.
- StatusOr<std::unique_ptr<Literal>> Reshape(
- tensorflow::gtl::ArraySlice<int64> dimensions) const;
-
- // Creates a new literal by reordering the dimensions of this literal.
- // The given `permutation` must be a permutation of the dimension numbers
- // in the original literal, and it specifies the order of the new dimensions
- // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
- // For example, a transpose call on a literal of shape [3 x 8 x 4] and
- // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
- // This literal must be an array.
- std::unique_ptr<Literal> Transpose(
- tensorflow::gtl::ArraySlice<int64> permutation) const;
-
- // Creates a sub-array from this literal by extracting the indices
- // [start_index, limit_index) of each dimension. The result literal has the
- // same rank and layout as for the given literal. The number of indices in
- // start_indices and limit_indices must be the rank of the literal, and the
- // indices follow the order of the dimensions.
- // This literal must be an array.
- std::unique_ptr<Literal> Slice(
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) const;
-
- // Creates a literal with a prepended dimension with bound "times"; e.g. a
- // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
- // literal replicated four times.
- // This literal must be an array.
- template <typename NativeT>
- std::unique_ptr<Literal> Replicate(int64 times) const;
-
- // Converts this literal to another primitive type using
- // static_cast<>. Returns an error if the conversion is not possible. This
- // literal must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> Convert(
- PrimitiveType primitive_dest_type) const;
-
- // Converts this literal to another primitive type using a bitcast
- // conversion. The to and from primitive types must have the same bit
- // width. Returns an error if the conversion is not possible. This literal
- // must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> BitcastConvert(
- PrimitiveType primitive_dest_type) const;
-
- // Converts this literal to the given shape. Returns an error is the
- // conversion is not possible.
- //
- // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
- // instead of truncation; otherwise, truncation is used.
- //
- // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
- // the default behavior.
- StatusOr<std::unique_ptr<Literal>> ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
-
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
-
// Creates a scalar literal value one of the given primitive type.
static Literal One(PrimitiveType primitive_type);
-
// Creates a scalar literal value containing the minimum value of the given
// primitive type. For floating-point types, returns -inf.
static Literal MinValue(PrimitiveType primitive_type);
-
// Creates a scalar literal value containing the maximum value of the given
// primitive type. For floating-point types, returns inf.
static Literal MaxValue(PrimitiveType primitive_type);
-
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
@@ -429,79 +856,6 @@ class Literal {
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z);
- // Clones this literal into a new Literal, or new std::unique_ptr<Literal>.
- Literal Clone() const;
- std::unique_ptr<Literal> CloneToUnique() const;
-
- // Gets or sets an element in the literal at the given index. The multi_index
- // is CHECKed against the dimension sizes.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const;
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value);
-
- // Overloads of Get and Set for array literals. CHECKs if the literal is not
- // array-shaped and dense.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
-
- // Returns the multi-index of the element in a sparse literal at the given
- // sparse element number. The sparse element number is the position with in
- // the sparse array's list of (index, value) pairs, and is checked against the
- // total number of (index, value) pairs in the sparse array.
- tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
- int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
-
- // Returns the value of the element in a sparse literal at the given sparse
- // element number. The sparse element number is the position with in the
- // sparse array's list of (index, value) pairs, and is checked against the
- // total number of (index, value) pairs in the sparse array.
- template <typename NativeT>
- NativeT GetSparseElement(int64 sparse_element_number,
- const ShapeIndex& shape_index = {}) const;
-
- // Appends the given element to the literal. If the elements are not appended
- // in sorted order, then SortSparseElements should be called before calling
- // other methods. This literal must have a sparse layout.
- template <typename NativeT>
- void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value, const ShapeIndex& shape_index = {});
-
- // Sorts the elements in a sparse array.
- void SortSparseElements(const ShapeIndex& shape_index = {});
-
- // Returns the element value at index (0, ..., 0), however many zeroes are
- // required for that index.
- template <typename NativeT>
- NativeT GetFirstElement() const;
-
- // Returns a literal scalar representing the first element.
- Literal GetFirstScalarLiteral() const;
-
- // As Get(), but determines the correct type and converts the value
- // into text.
- string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index = {}) const;
-
- // As GetSparseElement(), but determines the correct type and converts the
- // value into text.
- string GetSparseElementAsString(int64 sparse_element_number,
- const ShapeIndex& shape_index = {}) const;
-
- // As Get(), but determines the correct type and converts the value into
- // int64. This literal must be an array.
- StatusOr<int64> GetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index) const;
-
- // As Set(), but truncates `value` to the literal element type before storing.
- // This literal must be an array.
- Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
- int64 value);
-
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
@@ -511,6 +865,9 @@ class Literal {
static std::unique_ptr<Literal> MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements);
+ static std::unique_ptr<Literal> MakeTupleFromSlices(
+ tensorflow::gtl::ArraySlice<LiteralSlice> elements);
+
// As above, but intended to be invoked with move semantics; i.e.
//
// std::vector<std::unique_ptr<Literal>> elements = ...;
@@ -542,135 +899,105 @@ class Literal {
return MakeTupleOwned(std::move(v));
}
- // Returns a string representation of the literal value.
- // Warning: this function can take minutes for multi-million element Literals.
- string ToString(bool print_layout = false) const;
-
- // Invokes the "per cell" callback for each element in the provided
- // literal with the element's indices and a string representation of
- // the element's value.
- //
- // This function is useful if you want a polymorphic representation
- // of the tensor's elements (turning it to a string for something
- // like representation in a protobuf).
- //
- // This literal must have a dense layout.
- void EachCellAsString(
- const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- const string& value)>& per_cell) const;
- template <typename NativeT>
- void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value)>
- per_cell) const;
-
- // Populate this literal with the given values. Examples:
- //
- // // Populate with floats.
- // Array2D<float> float_values = ...
- // literal.PopulateR2FromArray2D(values);
- //
- // // Populate with int32s.
- // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
- //
- // The shape and element type of this literal must match given values. For
- // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
- // array of S32.
- template <typename NativeT>
- void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
- void PopulateR1(const tensorflow::core::Bitmap& values);
- template <typename NativeT>
- void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
- template <typename NativeT>
- void PopulateFromArray(const Array<NativeT>& values);
- template <typename NativeT>
- void PopulateR2FromArray2D(const Array2D<NativeT>& values);
- template <typename NativeT>
- void PopulateR3FromArray3D(const Array3D<NativeT>& values);
- template <typename NativeT>
- void PopulateR4FromArray4D(const Array4D<NativeT>& values);
-
- // Populates literal values by calling the generator function for every cell
- // in this literal object.
- //
- // generator must be a callable of the type
- // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
- //
- // This literal must have a dense layout.
- template <typename NativeT, typename FnType>
- Status Populate(const FnType& generator);
-
- // A parallel version of Populate(). This can be used if the generator is
- // thread-safe and the values for the shape's different elements are
- // independent.
- template <typename NativeT, typename FnType>
- Status PopulateParallel(const FnType& generator);
+ // Returns a vector containing the tuple elements of this Literal as separate
+ // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
+ // elements are moved into the new Literals; no data is copied. Upon return
+ // this Literal is set to a nil shape (empty tuple).
+ std::vector<Literal> DecomposeTuple();
- // Fills this literal with the given value.
- template <typename NativeT>
- void PopulateWithValue(NativeT value);
+ // This operation is the inverse of DecomposeTuple. The given elements are
+ // moved into the tuple elements of a new tuple-shaped Literal which is
+ // returned. Upon return, each of the Literals in 'elements' is set to a nil
+ // shape (empty tuple).
+ static Literal MoveIntoTuple(
+ tensorflow::gtl::MutableArraySlice<Literal> elements);
- // Returns whether every element in this literal is equal to value.
- //
- // value is an int8 because we expect this to be called with small
- // compile-time constants (0, -1, etc.) and so that whatever value you pass
- // can be represented exactly by floating-point types as small as 16 bits.
- //
- // If value doesn't fit in this literal's type, returns false. Values of 1/0
- // are considered equal to true/false; other values are not considered equal
- // to true. Also if this literal is not array-shaped false is returned.
- bool IsAll(int8 value) const;
+ // Creates a new Literal object with values of the given primitive_type and
+ // with dimensions defined by the dimensions parameter. Each element is
+ // initialized to the default value of that primitive type (0 for numeric
+ // types, false for predicates).
+ static std::unique_ptr<Literal> CreateFromDimensions(
+ PrimitiveType primitive_type,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
- // Like IsAll(const Literal&, int8), except we check whether the literal is
- // equal to a particular floating-point number.
- //
- // If the literal is not a floating-point value, this always returns false.
- //
- // This casts value to the type of literal, then compares using ==. The usual
- // admonishments about floating-point equality checks apply. We expect you to
- // use this to check for values that can be expressed precisely as a float,
- // e.g. -0.5. Also if this literal is not array-shaped false is returned.
- bool IsAllFloat(float value) const;
+ // If the given literal's data type is bfloat16, converts it to a float
+ // literal; otherwise, returns a copy of it. If the literal is a tuple,
+ // recursively converts its elements.
+ static std::unique_ptr<Literal> ConvertBF16ToF32(
+ const LiteralSlice& bf16_literal);
+
+ // If the given literal's data type is float, converts it to a bfloat16
+ // literal; otherwise, returns a copy of it. If the literal is a tuple,
+ // recursively converts its elements.
+ static std::unique_ptr<Literal> ConvertF32ToBF16(
+ const LiteralSlice& f32_literal);
+
+ // Creates a literal with a new shape with the given new dimensions using the
+ // data in the given input literal. For reshaping purposes the (flat) data
+ // buffer of the input literal is assumed to have the given minor_to_major
+ // layout order.
+ static std::unique_ptr<Literal> ReshapeSlice(
+ tensorflow::gtl::ArraySlice<int64> new_dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major,
+ const LiteralSlice& literal);
+
+ // Creates a literal with the supplied shape and uses the provided value
+ // generator to populate the literal's values.
+ // Returns the new literal object, or an error Status on failure.
+ template <
+ PrimitiveType type,
+ typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+ static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ const Shape& shape,
+ const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
+
+ // Creates a literal with the supplied shape and initializes its values
+ // using a normal distribution with the given mean and standard deviation
+ // (stddev), using 'engine' as the source of randomness.
+ // Returns the new literal object, or an error Status on failure.
+ template <
+ PrimitiveType type, typename E,
+ typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+ static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ const Shape& shape, E* engine, T mean, T stddev);
+
+ // Creates a literal with the supplied shape and initializes its values
+ // using a normal distribution with the given mean and standard deviation
+ // (stddev).
+ // Returns the new literal object, or an error Status on failure.
+ template <
+ PrimitiveType type,
+ typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+ static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ const Shape& shape, T mean, T stddev);
- // Like IsAll(const Literal&, int8), except we check whether the literal is
- // equal to a particular complex number.
//
- // If the literal is not a complex value, this always returns false.
- //
- // This casts value to the type of literal, then compares using ==. The usual
- // admonishments about floating-point equality checks apply. We expect you to
- // use this to check for complex values that can be expressed precisely as
- // float pairs e.g. (-0.5, 1.0).
- //
- // This literal must have a dense layout.
- bool IsAllComplex(complex64 value) const;
+ // End of factory methods.
- // Literal consists entirely of the first element of the literal.
- bool IsAllFirst() const;
+ // Returns a multi-dimensional index as a string. For example: '{7, 8}' is
+ // returned for a 2-dimensional index with index 7 in dimension 0 and index 8
+ // in dimension 1.
+ static string MultiIndexAsString(
+ tensorflow::gtl::ArraySlice<int64> multi_index);
- // Returns whether this literal is zero at the specified index. This literal
- // must be an array with a dense layout.
- bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
+ protected:
+ // Recursively sets the subshapes and buffers of all subpieces rooted at
+ // 'piece'. If 'allocate_arrays' is true, memory is allocated for the arrays
+ // in the shape.
+ void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
- // Return the count of the elements in the array at the given shape index in
- // this literal.
- int64 element_count(const ShapeIndex& index = {}) const {
- return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
+ // Returns the piece at the given ShapeIndex.
+ Piece& piece(const ShapeIndex& shape_index) {
+ return const_cast<Piece&>(LiteralBase::piece(shape_index));
}
- // Return the count of the elements in the sparse array at the given shape
- // index in this literal, which will be no larger than
- // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()).
- int64 sparse_element_count() const;
+ Piece& root_piece() const override { return *root_piece_; };
- // Compute a hash for this literal. This literal must not be a sparse tensor
- // or a tuple containing a sparse tensor.
- size_t Hash() const;
-
- protected:
+ private:
// Internal template helper for the Literal::CopySliceFrom(), matching its
// arguments one by one.
template <typename NativeT>
- Status CopySliceFromInternal(const Literal& src_literal,
+ Status CopySliceFromInternal(const LiteralBase& src_literal,
tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size);
@@ -698,162 +1025,40 @@ class Literal {
int64 minor_loop_size = 1;
};
- // A data structure representing a subshape at a particular ShapeIndex within
- // the literal. For array-shaped ShapeIndexes, this data structure holds the
- // pointer to the memory allocated for the array data.
- class Piece {
- public:
- // Return the buffer holding the array data for this piece as an array
- // slice. This piece must be array-shaped.
- template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data() const;
- template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> data();
-
- // Return the buffer holding the array data for this piece as a void*. This
- // piece must be array-shaped.
- void* untyped_data();
- const void* untyped_data() const;
+ // The Literal class always owns the shape; LiteralBase only borrows it.
+ std::unique_ptr<Shape> shape_;
- // Gets or sets an element in the array at the given index. The multi_index
- // is CHECKed against the dimension sizes of the array. This piece must be
- // array-shaped.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
-
- // Gets/sets the buffer holding the array data.
- char* buffer() const { return buffer_; }
- void set_buffer(char* buffer) { buffer_ = buffer; }
-
- // The array of multi-indices that provide the locations of non-zero
- // elements in a sparse array. Only used if
- // LayoutUtil::IsSparseArray(shape()) is true.
- SparseIndexArray* sparse_indices() const { return sparse_indices_; }
- void set_sparse_indices(SparseIndexArray* sparse_indices) {
- sparse_indices_ = sparse_indices;
- }
-
- // Gets or sets the subshape of this piece. This reference points to a
- // subshape within the shape in the containing Literal (Literal::shape_).
- const Shape& subshape() const { return *subshape_; }
- void set_subshape(const Shape* subshape) { subshape_ = subshape; }
-
- // Returns the size in bytes of the buffer holding the array data.
- int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
-
- // Returns the number of elements in this piece's array.
- int64 element_count() const {
- // If this is a sparse array, use the number of elements represented by
- // the indices in the associated SparseIndexArray.
- return LayoutUtil::IsSparseArray(subshape())
- ? sparse_indices()->index_count()
- : ShapeUtil::ElementsIn(subshape());
- }
-
- // Copy the data from 'src' into this piece's buffer. Shapes of this piece
- // and src must be compatible.
- Status CopyFrom(const Piece& src);
-
- // Returns true if this piece and 'other' contain the same data. This piece
- // and 'other' must be array-shaped and compatible.
- bool EqualElements(const Piece& other) const;
-
- // Writes the shape and data (if array-shaped) into the given proto.
- void WriteToProto(LiteralProto* proto) const;
-
- // Copies the data from the given proto into this piece. The shape of this
- // piece must be equal (not just compatible) to the shape of the proto.
- Status CopyFromProto(const LiteralProto& proto);
-
- // Sorts the elements in a sparse array.
- void SortSparseElements();
-
- private:
- // Recursive helper for EqualElements.
- template <typename NativeT>
- bool EqualElementsInternal(const Piece& other,
- std::vector<int64>* multi_index) const;
-
- // Helper for SortSparseElements that has the element type as a template
- // parameter.
- template <typename NativeT>
- void SortSparseElementsInternal();
-
- // For array-shaped pieces, this is the buffer holding the literal data.
- char* buffer_ = nullptr;
-
- // For sparse arrays, this is the array of indices.
- SparseIndexArray* sparse_indices_ = nullptr;
-
- // The shape of piece. This points into the shape of the containing Literal
- // (Literal::shape_).
- const Shape* subshape_ = nullptr;
- };
-
- // Returns the piece at the given ShapeIndex.
- Piece& piece(const ShapeIndex& shape_index) {
- return *pieces_.mutable_element(shape_index);
- }
- const Piece& piece(const ShapeIndex& shape_index) const {
- return pieces_.element(shape_index);
- }
-
- // Returns the piece at the root of the shape (empty ShapeIndex).
- Piece& root_piece() { return piece({}); }
- const Piece& root_piece() const { return piece({}); }
-
- // Deallocate the buffers held by this literal (if the literal owns the
- // buffer).
- void DeallocateBuffers();
+ Piece* root_piece_ = nullptr;
// Implementation details shared between Populate() and PopulateParallel()
template <typename NativeT, typename FnType>
Status PopulateInternal(const FnType& generator, bool parallel);
- Shape shape_;
- ShapeTree<Piece> pieces_;
-
- // Whether the buffers held in pieces_ are owned by this Literal.
- bool owns_buffers_;
+ // Deallocate the buffers held by this literal.
+ void DeallocateBuffers();
- // LiteralView must access and manipulate Pieces of other Literals.
- friend class LiteralView;
-}; // namespace xla
+ friend class LiteralBase;
+};
std::ostream& operator<<(std::ostream& out, const Literal& literal);
-// A read-only view of a Literal. A LiteralView contains pointers to buffers
-// owned by the viewed Literal.
-//
-// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable
-// and mutable) similar to (Mutable)ArraySlice.
-class LiteralView : public Literal {
+// A read-only view of a Literal. A LiteralSlice holds pointers to a shape and
+// to literal buffers that are always owned by others (typically the viewed
+// Literal).
+class LiteralSlice : public LiteralBase {
public:
- // Create and return a view of the given literal rooted at the given shape
- // index within the given literal. A factory is used rather than a public
- // constructor because only const LiteralViews are supported. It's still
- // possible to create non-const LiteralViews via the copy constructors, but
- // the factory method makes it a bit less likely. Implementing literal slices
- // will fix this undesirable situation (b/71550060).
- static const LiteralView Create(const Literal& literal,
- const ShapeIndex& view_root = {});
-
- LiteralView(const LiteralView& other);
- LiteralView& operator=(const LiteralView& other);
-
- virtual ~LiteralView();
+ LiteralSlice() : LiteralBase() {}
+ // Implicit conversion constructor that can also accept Literal.
+ LiteralSlice(const LiteralBase& literal);
+ LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
private:
- LiteralView(const Literal& literal, const ShapeIndex& view_root);
+ const Piece& root_piece() const override { return *root_piece_; };
- // Helper for the copy constructor and copy assignment operator.
- void CopyFrom(const LiteralView& other);
+ const Piece* root_piece_; // Not owned.
};
template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
+tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
CHECK_EQ(subshape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>())
@@ -866,7 +1071,7 @@ tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
}
template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
+tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
CHECK_EQ(subshape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>())
@@ -879,7 +1084,7 @@ tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
}
template <typename NativeT>
-NativeT Literal::Piece::Get(
+NativeT LiteralBase::Piece::Get(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(subshape()));
return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
@@ -887,15 +1092,15 @@ NativeT Literal::Piece::Get(
}
template <typename NativeT>
-void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value) {
+void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value) {
CHECK(LayoutUtil::IsDenseArray(subshape()));
data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
subshape(), multi_index)] = value;
}
template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> Literal::data(
+tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
const ShapeIndex& shape_index) const {
return piece(shape_index).data<NativeT>();
}
@@ -907,13 +1112,13 @@ tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
}
template <typename NativeT>
-inline NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const {
+inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const {
return piece(shape_index).Get<NativeT>(multi_index);
}
template <typename NativeT>
-inline NativeT Literal::Get(
+inline NativeT LiteralBase::Get(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
return root_piece().Get<NativeT>(multi_index);
}
@@ -1160,13 +1365,13 @@ template <typename NativeT>
}
template <typename NativeT>
-NativeT Literal::GetFirstElement() const {
+NativeT LiteralBase::GetFirstElement() const {
return data<NativeT>().at(0);
}
template <typename NativeT>
-NativeT Literal::GetSparseElement(int64 sparse_element_number,
- const ShapeIndex& shape_index) const {
+NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
+ const ShapeIndex& shape_index) const {
CHECK(
LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
return data<NativeT>(shape_index)[sparse_element_number];
@@ -1199,7 +1404,7 @@ template <typename NativeT>
}
template <typename NativeT>
-void Literal::EachCell(
+void LiteralBase::EachCell(
std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
NativeT value)>
per_cell) const {
@@ -1375,7 +1580,7 @@ template <typename NativeT>
}
template <typename NativeT>
-std::unique_ptr<Literal> Literal::Replicate(int64 times) const {
+std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
DimensionVector bounds = {times};
bounds.reserve(shape().dimensions_size() + 1);
for (int64 bound : shape().dimensions()) {
@@ -1410,6 +1615,38 @@ std::unique_ptr<Literal> Literal::Replicate(int64 times) const {
return literal;
}
+template <PrimitiveType type, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
+ const Shape& shape,
+ const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
+ TF_RET_CHECK(shape.element_type() == type);
+ std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
+ TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
+ [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ return generator(indexes);
+ }));
+ return std::move(literal);
+}
+
+template <PrimitiveType type, typename E, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
+ const Shape& shape, E* engine, T mean, T stddev) {
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
+ std::normal_distribution<NativeT> generator(mean, stddev);
+ return CreateRandomLiteral<type, NativeT>(
+ shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
+ return generator(*engine);
+ });
+}
+
+template <PrimitiveType type, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
+ const Shape& shape, T mean, T stddev) {
+ std::minstd_rand0 engine;
+ return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
+}
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
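A minimal usage sketch (not part of the patch) for the APIs this header change introduces: the non-owning LiteralSlice view, DecomposeTuple, and CreateRandomLiteral. The function name LiteralApiSketch is illustrative only.

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

void LiteralApiSketch() {
  // Build a tuple literal from two owned pieces; MakeTuple copies the data.
  auto scalar = xla::Literal::CreateR0<float>(7.0f);
  auto matrix = xla::Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto tuple = xla::Literal::MakeTuple({scalar.get(), matrix.get()});

  // LiteralSlice is a non-owning view, here rooted at tuple element 1.
  xla::LiteralSlice matrix_view(*tuple, /*view_root=*/{1});
  float v = matrix_view.Get<float>({0, 1});  // 2.0f, no data copied.
  (void)v;

  // DecomposeTuple moves the elements out, leaving 'tuple' nil-shaped;
  // MoveIntoTuple is the inverse operation.
  std::vector<xla::Literal> parts = tuple->DecomposeTuple();
  (void)parts;

  // CreateRandomLiteral fills a freshly allocated literal from N(mean, stddev).
  std::unique_ptr<xla::Literal> random =
      xla::Literal::CreateRandomLiteral<xla::F32>(
          xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), /*mean=*/0.0f,
          /*stddev=*/1.0f)
          .ConsumeValueOrDie();
  (void)random;
}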
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 61046784e0..087d509f28 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -974,7 +974,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) {
Literal::CreateR1<double>({2.0, 4.0}).get(),
&nil_literal});
- EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
+ EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
@@ -985,7 +985,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) {
/*src_shape_index=*/{}));
// The matrix element should be unchanged.
- EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
+ EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
// The tuple element should have been copied from 'tuple'.
EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
@@ -1373,36 +1373,36 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
ASSERT_EQ(h1, r[3]);
}
-TEST_F(LiteralUtilTest, LiteralViewTest) {
+TEST_F(LiteralUtilTest, LiteralSliceTest) {
auto scalar = Literal::CreateR0<float>(1.0);
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar);
- EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix);
- EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple);
- EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple);
- EXPECT_EQ(LiteralView::Create(nil, {}), nil);
+ EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar);
+ EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix);
+ EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple);
+ EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple);
+ EXPECT_EQ(LiteralSlice(nil, {}), nil);
- EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar);
- EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix);
+ EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar);
+ EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix);
- EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple);
- EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar);
- EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix);
- EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar);
+ EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple);
+ EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar);
+ EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix);
+ EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar);
}
-TEST_F(LiteralUtilTest, MutatingLiteralView) {
+TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
auto scalar = Literal::CreateR0<float>(1.0);
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
// Verify that changing the underlying data beneath the view changes the
// data of the view itself.
- const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
+ const auto nested_tuple_view = LiteralSlice(*nested_tuple);
EXPECT_EQ(
nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
1.0f);
@@ -1418,16 +1418,15 @@ TEST_F(LiteralUtilTest, MutatingLiteralView) {
555.0f);
}
-TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) {
+TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
auto scalar = Literal::CreateR0<float>(1.0);
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
- const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
- const auto tuple_view =
- LiteralView::Create(nested_tuple_view, /*view_root=*/{0});
- const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1});
+ const auto nested_tuple_view = LiteralSlice(*nested_tuple);
+ const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
+ const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
@@ -1533,11 +1532,11 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
}
-TEST_F(LiteralUtilTest, LiteralViewCopy) {
+TEST_F(LiteralUtilTest, LiteralSliceCopy) {
std::unique_ptr<Literal> matrix =
Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- const auto matrix_view = LiteralView::Create(*matrix);
- LiteralView matrix_view_copy(matrix_view);
+ const auto matrix_view = LiteralSlice(*matrix);
+ LiteralSlice matrix_view_copy(matrix_view);
EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
EXPECT_EQ(matrix_view_copy.Get<float>({0, 1}), 2.0);
diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h
index 8db8c6f3de..3c74e070da 100644
--- a/tensorflow/compiler/xla/map_util.h
+++ b/tensorflow/compiler/xla/map_util.h
@@ -86,11 +86,10 @@ const typename Collection::value_type::second_type& FindOrDefault(
// Inserts the key-value pair into the collection. Dies if key was already
// present.
-template <class Collection>
-void InsertOrDie(Collection* const collection,
- const typename Collection::value_type::first_type& key,
- const typename Collection::value_type::second_type& data) {
- auto p = collection->insert(std::make_pair(key, data));
+template <class Collection, class Key, class Value>
+void InsertOrDie(Collection* const collection, Key&& key, Value&& value) {
+ auto p = collection->insert(
+ std::make_pair(std::forward<Key>(key), std::forward<Value>(value)));
CHECK(p.second) << "duplicate key: " << key;
}
@@ -101,9 +100,10 @@ bool ContainsKey(const Collection& collection, const Key& key) {
}
// Inserts `value` into `set`. Dies if it was already present.
-template <class Set>
-void InsertOrDie(Set* const set, const typename Set::value_type& value) {
- CHECK(set->insert(value).second) << "duplicate value: " << value;
+template <class Set, class Value>
+void InsertOrDie(Set* const set, Value&& value) {
+ CHECK(set->insert(std::forward<Value>(value)).second)
+ << "duplicate value: " << value;
}
} // namespace xla
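A short sketch (not part of the patch) of what the forwarding InsertOrDie overloads above enable: keys and values are forwarded into the container, so move-only mapped types can be inserted without a copy. The function name MapUtilSketch is illustrative only.

#include <map>
#include <memory>
#include <set>
#include <string>

#include "tensorflow/compiler/xla/map_util.h"

void MapUtilSketch() {
  std::map<std::string, std::unique_ptr<int>> m;
  // The Key&&/Value&& overload forwards both arguments, so a move-only value
  // such as std::unique_ptr needs no copy; the previous signature took const
  // references and required a copyable value type.
  xla::InsertOrDie(&m, std::string("answer"), std::unique_ptr<int>(new int(42)));

  std::set<std::string> s;
  xla::InsertOrDie(&s, std::string("once"));  // CHECK-fails on a duplicate.
}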
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index dc6f5fe5fc..68648a3a17 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -340,13 +340,13 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o) {
return result;
}
-PyObject* PyObjectFromXlaLiteral(const Literal& literal) {
+PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
if (ShapeUtil::IsTuple(literal.shape())) {
int num_elements = ShapeUtil::TupleElementCount(literal.shape());
PyObject* tuple = PyTuple_New(num_elements);
for (int i = 0; i < num_elements; i++) {
- PyTuple_SET_ITEM(
- tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i})));
+ PyTuple_SET_ITEM(tuple, i,
+ PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
}
return tuple;
} else {
@@ -431,7 +431,7 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
return Status::OK();
}
-void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
+void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
PyArrayObject* py_array) {
switch (np_type) {
case NPY_BOOL:
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
index 9656cb1c31..64f0aae0f9 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.h
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -74,7 +74,7 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o);
// array data.
//
// The return value is a new reference.
-PyObject* PyObjectFromXlaLiteral(const Literal& literal);
+PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
// Converts a Numpy ndarray or a nested Python tuple thereof to a
// corresponding XLA literal.
@@ -90,7 +90,7 @@ StatusOr<std::unique_ptr<Literal> > XlaLiteralFromPyObject(PyObject* o);
Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
Literal* literal);
-void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
+void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
PyArrayObject* py_array);
template <typename NativeT>
@@ -101,7 +101,8 @@ void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
}
template <typename NativeT>
-void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) {
+void CopyLiteralToNumpyArray(const LiteralSlice& literal,
+ PyArrayObject* py_array) {
NativeT* dest = static_cast<NativeT*>(PyArray_DATA(py_array));
auto source = literal.data<NativeT>();
std::copy(source.begin(), source.end(), dest);
diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
index 10997c0719..313f11a9a9 100644
--- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
@@ -101,8 +101,8 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
computation, {}, nullptr));
- LiteralTestUtil::ExpectNear(*expected_literal, *result_literal,
- ErrorSpec(0.0001));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
+ ErrorSpec(0.0001)));
}
} // namespace
diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc
index ffb72fc73c..5f4dc6bd08 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_service.cc
@@ -27,8 +27,8 @@ namespace xla {
return std::move(grpc_service);
}
-::grpc::Status DelegateRPC(std::function<tensorflow::Status()> op) {
- tensorflow::Status s = op();
+::grpc::Status DelegateRPC(std::function<Status()> op) {
+ Status s = op();
return tensorflow::ToGrpcStatus(s);
}
diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc
index e1f2b0abe3..620ac6cec4 100644
--- a/tensorflow/compiler/xla/rpc/grpc_stub.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc
@@ -20,53 +20,49 @@ namespace xla {
GRPCStub::~GRPCStub() = default;
-tensorflow::Status MakeRPC(
+Status MakeRPC(
const std::function<::grpc::Status(::grpc::ClientContext*)>& rpc_method) {
::grpc::ClientContext context;
::grpc::Status s = rpc_method(&context);
return tensorflow::FromGrpcStatus(s);
}
-tensorflow::Status GRPCStub::TransferToClient(
- const TransferToClientRequest* request,
- TransferToClientResponse* response) {
+Status GRPCStub::TransferToClient(const TransferToClientRequest* request,
+ TransferToClientResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->TransferToClient(context, *request, response);
});
}
-tensorflow::Status GRPCStub::TransferToServer(
- const TransferToServerRequest* request,
- TransferToServerResponse* response) {
+Status GRPCStub::TransferToServer(const TransferToServerRequest* request,
+ TransferToServerResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->TransferToServer(context, *request, response);
});
}
-tensorflow::Status GRPCStub::TransferToInfeed(
- const TransferToInfeedRequest* request,
- TransferToInfeedResponse* response) {
+Status GRPCStub::TransferToInfeed(const TransferToInfeedRequest* request,
+ TransferToInfeedResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->TransferToInfeed(context, *request, response);
});
}
-tensorflow::Status GRPCStub::TransferFromOutfeed(
- const TransferFromOutfeedRequest* request,
- TransferFromOutfeedResponse* response) {
+Status GRPCStub::TransferFromOutfeed(const TransferFromOutfeedRequest* request,
+ TransferFromOutfeedResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->TransferFromOutfeed(context, *request, response);
});
}
-tensorflow::Status GRPCStub::ResetDevice(const ResetDeviceRequest* request,
- ResetDeviceResponse* response) {
+Status GRPCStub::ResetDevice(const ResetDeviceRequest* request,
+ ResetDeviceResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->ResetDevice(context, *request, response);
});
}
-tensorflow::Status GRPCStub::LoadComputationSnapshot(
+Status GRPCStub::LoadComputationSnapshot(
const LoadComputationSnapshotRequest* request,
LoadComputationSnapshotResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
@@ -74,28 +70,28 @@ tensorflow::Status GRPCStub::LoadComputationSnapshot(
});
}
-tensorflow::Status GRPCStub::Execute(const ExecuteRequest* request,
- ExecuteResponse* response) {
+Status GRPCStub::Execute(const ExecuteRequest* request,
+ ExecuteResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->Execute(context, *request, response);
});
}
-tensorflow::Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request,
- ExecuteResponse* response) {
+Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request,
+ ExecuteResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->ExecuteGraph(context, *request, response);
});
}
-tensorflow::Status GRPCStub::ExecuteParallel(
- const ExecuteParallelRequest* request, ExecuteParallelResponse* response) {
+Status GRPCStub::ExecuteParallel(const ExecuteParallelRequest* request,
+ ExecuteParallelResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->ExecuteParallel(context, *request, response);
});
}
-tensorflow::Status GRPCStub::ExecuteGraphParallel(
+Status GRPCStub::ExecuteGraphParallel(
const ExecuteGraphParallelRequest* request,
ExecuteParallelResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
@@ -103,38 +99,35 @@ tensorflow::Status GRPCStub::ExecuteGraphParallel(
});
}
-tensorflow::Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request,
- ExecuteAsyncResponse* response) {
+Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request,
+ ExecuteAsyncResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->ExecuteAsync(context, *request, response);
});
}
-tensorflow::Status GRPCStub::WaitForExecution(
- const WaitForExecutionRequest* request,
- WaitForExecutionResponse* response) {
+Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request,
+ WaitForExecutionResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->WaitForExecution(context, *request, response);
});
}
-tensorflow::Status GRPCStub::DeconstructTuple(
- const DeconstructTupleRequest* request,
- DeconstructTupleResponse* response) {
+Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request,
+ DeconstructTupleResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->DeconstructTuple(context, *request, response);
});
}
-tensorflow::Status GRPCStub::GetComputationStats(
- const ComputationStatsRequest* request,
- ComputationStatsResponse* response) {
+Status GRPCStub::GetComputationStats(const ComputationStatsRequest* request,
+ ComputationStatsResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->GetComputationStats(context, *request, response);
});
}
-tensorflow::Status GRPCStub::GetComputationGraphStats(
+Status GRPCStub::GetComputationGraphStats(
const ComputationGraphStatsRequest* request,
ComputationStatsResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
@@ -142,81 +135,77 @@ tensorflow::Status GRPCStub::GetComputationGraphStats(
});
}
-tensorflow::Status GRPCStub::GetComputationShape(
- const GetComputationShapeRequest* request,
- GetComputationShapeResponse* response) {
+Status GRPCStub::GetComputationShape(const GetComputationShapeRequest* request,
+ GetComputationShapeResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->GetComputationShape(context, *request, response);
});
}
-tensorflow::Status GRPCStub::GetShape(const GetShapeRequest* request,
- GetShapeResponse* response) {
+Status GRPCStub::GetShape(const GetShapeRequest* request,
+ GetShapeResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->GetShape(context, *request, response);
});
}
-tensorflow::Status GRPCStub::GetDeviceHandles(
- const GetDeviceHandlesRequest* request,
- GetDeviceHandlesResponse* response) {
+Status GRPCStub::GetDeviceHandles(const GetDeviceHandlesRequest* request,
+ GetDeviceHandlesResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->GetDeviceHandles(context, *request, response);
});
}
-tensorflow::Status GRPCStub::CreateChannelHandle(
- const CreateChannelHandleRequest* request,
- CreateChannelHandleResponse* response) {
+Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request,
+ CreateChannelHandleResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->CreateChannelHandle(context, *request, response);
});
}
// Methods used by ComputationBuilder.
-tensorflow::Status GRPCStub::Computation(const ComputationRequest* request,
- ComputationResponse* response) {
+Status GRPCStub::Computation(const ComputationRequest* request,
+ ComputationResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->Computation(context, *request, response);
});
}
-tensorflow::Status GRPCStub::Op(const OpRequest* request,
- OpResponse* response) {
+Status GRPCStub::Op(const OpRequest* request, OpResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->CreateOp(context, *request, response);
});
}
-tensorflow::Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request,
- GetLocalShapeResponse* response) {
+Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request,
+ GetLocalShapeResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->GetLocalShape(context, *request, response);
});
}
-tensorflow::Status GRPCStub::SetReturnValue(
- const SetReturnValueRequest* request, SetReturnValueResponse* responses) {
+Status GRPCStub::SetReturnValue(const SetReturnValueRequest* request,
+ SetReturnValueResponse* responses) {
return MakeRPC([this, request, responses](::grpc::ClientContext* context) {
return grpc_stub_->SetReturnValue(context, *request, responses);
});
}
-tensorflow::Status GRPCStub::IsConstant(const IsConstantRequest* request,
- IsConstantResponse* response) {
+Status GRPCStub::IsConstant(const IsConstantRequest* request,
+ IsConstantResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->IsConstant(context, *request, response);
});
}
-tensorflow::Status GRPCStub::ComputeConstant(
- const ComputeConstantRequest* request, ComputeConstantResponse* response) {
+Status GRPCStub::ComputeConstant(const ComputeConstantRequest* request,
+ ComputeConstantResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->ComputeConstant(context, *request, response);
});
}
-tensorflow::Status GRPCStub::ComputeConstantGraph(
+Status GRPCStub::ComputeConstantGraph(
const ComputeConstantGraphRequest* request,
ComputeConstantResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
@@ -225,17 +214,16 @@ tensorflow::Status GRPCStub::ComputeConstantGraph(
}
// Methods used by Computation.
-tensorflow::Status GRPCStub::SnapshotComputation(
- const SnapshotComputationRequest* request,
- SnapshotComputationResponse* response) {
+Status GRPCStub::SnapshotComputation(const SnapshotComputationRequest* request,
+ SnapshotComputationResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->SnapshotComputation(context, *request, response);
});
}
// Methods used by GlobalData.
-tensorflow::Status GRPCStub::Unregister(const UnregisterRequest* request,
- UnregisterResponse* response) {
+Status GRPCStub::Unregister(const UnregisterRequest* request,
+ UnregisterResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->Unregister(context, *request, response);
});
diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h
index fd9810d4f1..5906d45769 100644
--- a/tensorflow/compiler/xla/rpc/grpc_stub.h
+++ b/tensorflow/compiler/xla/rpc/grpc_stub.h
@@ -28,105 +28,90 @@ class GRPCStub : public ServiceInterface {
explicit GRPCStub(grpc::XlaService::Stub* stub) : grpc_stub_(stub) {}
~GRPCStub() override;
- tensorflow::Status TransferToClient(
- const TransferToClientRequest* arg,
- TransferToClientResponse* result) override;
+ Status TransferToClient(const TransferToClientRequest* arg,
+ TransferToClientResponse* result) override;
- tensorflow::Status TransferToServer(
- const TransferToServerRequest* arg,
- TransferToServerResponse* result) override;
+ Status TransferToServer(const TransferToServerRequest* arg,
+ TransferToServerResponse* result) override;
- tensorflow::Status TransferToInfeed(
- const TransferToInfeedRequest* arg,
- TransferToInfeedResponse* result) override;
+ Status TransferToInfeed(const TransferToInfeedRequest* arg,
+ TransferToInfeedResponse* result) override;
- tensorflow::Status TransferFromOutfeed(
- const TransferFromOutfeedRequest* arg,
- TransferFromOutfeedResponse* result) override;
+ Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
+ TransferFromOutfeedResponse* result) override;
- tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
- ResetDeviceResponse* result) override;
+ Status ResetDevice(const ResetDeviceRequest* arg,
+ ResetDeviceResponse* result) override;
- tensorflow::Status LoadComputationSnapshot(
+ Status LoadComputationSnapshot(
const LoadComputationSnapshotRequest* request,
LoadComputationSnapshotResponse* result) override;
- tensorflow::Status Execute(const ExecuteRequest* arg,
- ExecuteResponse* result) override;
+ Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override;
- tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* request,
- ExecuteResponse* response) override;
+ Status ExecuteGraph(const ExecuteGraphRequest* request,
+ ExecuteResponse* response) override;
- tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
- ExecuteParallelResponse* result) override;
+ Status ExecuteParallel(const ExecuteParallelRequest* arg,
+ ExecuteParallelResponse* result) override;
- tensorflow::Status ExecuteGraphParallel(
- const ExecuteGraphParallelRequest* request,
- ExecuteParallelResponse* response) override;
+ Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request,
+ ExecuteParallelResponse* response) override;
- tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) override;
+ Status ExecuteAsync(const ExecuteAsyncRequest* arg,
+ ExecuteAsyncResponse* result) override;
- tensorflow::Status WaitForExecution(
- const WaitForExecutionRequest* arg,
- WaitForExecutionResponse* result) override;
+ Status WaitForExecution(const WaitForExecutionRequest* arg,
+ WaitForExecutionResponse* result) override;
- tensorflow::Status DeconstructTuple(
- const DeconstructTupleRequest* arg,
- DeconstructTupleResponse* result) override;
+ Status DeconstructTuple(const DeconstructTupleRequest* arg,
+ DeconstructTupleResponse* result) override;
- tensorflow::Status GetComputationStats(
- const ComputationStatsRequest* arg,
- ComputationStatsResponse* result) override;
+ Status GetComputationStats(const ComputationStatsRequest* arg,
+ ComputationStatsResponse* result) override;
- tensorflow::Status GetComputationGraphStats(
- const ComputationGraphStatsRequest* request,
- ComputationStatsResponse* response) override;
+ Status GetComputationGraphStats(const ComputationGraphStatsRequest* request,
+ ComputationStatsResponse* response) override;
- tensorflow::Status GetComputationShape(
- const GetComputationShapeRequest* arg,
- GetComputationShapeResponse* result) override;
+ Status GetComputationShape(const GetComputationShapeRequest* arg,
+ GetComputationShapeResponse* result) override;
- tensorflow::Status GetShape(const GetShapeRequest* arg,
- GetShapeResponse* result) override;
+ Status GetShape(const GetShapeRequest* arg,
+ GetShapeResponse* result) override;
- tensorflow::Status GetDeviceHandles(
- const GetDeviceHandlesRequest* arg,
- GetDeviceHandlesResponse* result) override;
+ Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
+ GetDeviceHandlesResponse* result) override;
- tensorflow::Status CreateChannelHandle(
- const CreateChannelHandleRequest* arg,
- CreateChannelHandleResponse* result) override;
+ Status CreateChannelHandle(const CreateChannelHandleRequest* arg,
+ CreateChannelHandleResponse* result) override;
// Methods used by ComputationBuilder.
- tensorflow::Status Computation(const ComputationRequest* arg,
- ComputationResponse* result) override;
+ Status Computation(const ComputationRequest* arg,
+ ComputationResponse* result) override;
- tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override;
- tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg,
- GetLocalShapeResponse* result) override;
+ Status Op(const OpRequest* arg, OpResponse* result) override;
+ Status GetLocalShape(const GetLocalShapeRequest* arg,
+ GetLocalShapeResponse* result) override;
- tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg,
- SetReturnValueResponse* results) override;
+ Status SetReturnValue(const SetReturnValueRequest* arg,
+ SetReturnValueResponse* results) override;
- tensorflow::Status IsConstant(const IsConstantRequest* arg,
- IsConstantResponse* result) override;
+ Status IsConstant(const IsConstantRequest* arg,
+ IsConstantResponse* result) override;
- tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg,
- ComputeConstantResponse* result) override;
+ Status ComputeConstant(const ComputeConstantRequest* arg,
+ ComputeConstantResponse* result) override;
- tensorflow::Status ComputeConstantGraph(
- const ComputeConstantGraphRequest* arg,
- ComputeConstantResponse* result) override;
+ Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
+ ComputeConstantResponse* result) override;
// Methods used by Computation.
- tensorflow::Status SnapshotComputation(
- const SnapshotComputationRequest* ag,
- SnapshotComputationResponse* result) override;
+ Status SnapshotComputation(const SnapshotComputationRequest* ag,
+ SnapshotComputationResponse* result) override;
// Methods used by GlobalData.
- tensorflow::Status Unregister(const UnregisterRequest* arg,
- UnregisterResponse* result) override;
+ Status Unregister(const UnregisterRequest* arg,
+ UnregisterResponse* result) override;
grpc::XlaService::Stub* service() { return grpc_stub_; }
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index aa3a6261e0..04a9a4a887 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -12,6 +12,7 @@ package_group(
],
)
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
@@ -1010,6 +1011,7 @@ cc_library(
],
deps = [
":buffer_liveness",
+ ":buffer_value_containers",
":heap_simulator",
":hlo",
":hlo_proto",
@@ -1098,11 +1100,12 @@ cc_library(
srcs = ["heap_simulator.cc"],
hdrs = ["heap_simulator.h"],
deps = [
+ ":buffer_value",
+ ":buffer_value_containers",
":hlo",
":hlo_ordering",
":hlo_proto",
":liveness_util",
- ":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@@ -1118,7 +1121,7 @@ tf_cc_test(
":heap_simulator",
":hlo",
":hlo_ordering",
- ":logical_buffer",
+ ":hlo_value",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1268,13 +1271,11 @@ cc_library(
deps = [
":hlo",
":hlo_pass",
- ":hlo_query",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
@@ -1359,6 +1360,48 @@ tf_cc_test(
],
)
+cc_library(
+ name = "batch_dot_simplification",
+ srcs = ["batch_dot_simplification.cc"],
+ hdrs = ["batch_dot_simplification.h"],
+ deps = [
+ ":hlo",
+ ":hlo_creation_utils",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "batch_dot_simplification_test",
+ srcs = ["batch_dot_simplification_test.cc"],
+ deps = [
+ ":batch_dot_simplification",
+ ":hlo",
+ ":hlo_matchers",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
tf_cc_test(
name = "gather_expander_test",
srcs = ["gather_expander_test.cc"],
@@ -1721,6 +1764,7 @@ tf_cc_test(
":hlo_execution_profile",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
],
)
@@ -1786,6 +1830,17 @@ cc_library(
)
cc_library(
+ name = "buffer_value_containers",
+ hdrs = ["buffer_value_containers.h"],
+ deps = [
+ ":buffer_value",
+ ":logical_buffer",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
name = "logical_buffer",
srcs = ["logical_buffer.cc"],
hdrs = ["logical_buffer.h"],
@@ -2316,8 +2371,14 @@ tf_cc_test(
cc_library(
name = "device_memory_allocator",
- srcs = ["device_memory_allocator.cc"],
- hdrs = ["device_memory_allocator.h"],
+ srcs = [
+ "device_memory_allocator.cc",
+ "owning_device_memory.cc",
+ ],
+ hdrs = [
+ "device_memory_allocator.h",
+ "owning_device_memory.h",
+ ],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -2352,6 +2413,24 @@ cc_library(
],
)
+xla_test(
+ name = "elemental_ir_emitter_test",
+ srcs = ["elemental_ir_emitter_test.cc"],
+ backends = [
+ "cpu",
+ "gpu",
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:execution_options_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ ],
+)
+
cc_library(
name = "hlo_module_config",
srcs = ["hlo_module_config.cc"],
@@ -2529,7 +2608,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4ec79a0244..f732ed8f39 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -92,26 +92,6 @@ bool ReshapeIsBitcast(
valid_bitcast_callback(operand->shape(), reshape->shape());
}
-// Adds a scalar computation to the module to enable optimizations with dot
-// converting into reduction.
-HloComputation* CreateScalarBinaryComputation(HloModule* module,
- PrimitiveType primitive_type,
- HloOpcode opcode) {
- HloComputation::Builder b("scalar_computation");
- auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs"));
- auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs"));
- auto scalar_op = b.AddInstruction(
- HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
- opcode, scalar_lhs, scalar_rhs));
- HloComputation* scalar_computation =
- module->AddEmbeddedComputation(b.Build(scalar_op));
- return scalar_computation;
-}
-
-} // namespace
-
// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
// algebraic expressions to simplified forms. Note: This only supports
// simplifications that simply look at the operands of an instruction. For the
@@ -220,8 +200,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
HloInstruction* zero = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
- HloComputation* AddReduce_computation = CreateScalarBinaryComputation(
- computation_->parent(), F32, HloOpcode::kAdd);
+ HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
shape, hlo, zero, {dim}, AddReduce_computation));
@@ -293,6 +272,24 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);
+ HloComputation* GetOrCreateScalarAddComputation() {
+ if (scalar_add_computation_) {
+ return scalar_add_computation_;
+ }
+
+ HloComputation::Builder b("scalar_add_computation");
+ Shape shape = ShapeUtil::MakeShape(F32, {});
+ auto scalar_lhs = b.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
+ auto scalar_rhs = b.AddInstruction(
+ HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
+ auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
+ scalar_add_computation_ =
+ computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
+ return scalar_add_computation_;
+ }
+
// Current HloComputation instance the AlgebraicSimplifierVisitor is
// traversing.
HloComputation* computation_;
@@ -311,8 +308,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Disable convolution simplification on platforms where it causes a slowdown.
bool enable_conv_simplification_;
+
+  // Cached computation for adding two scalar F32 values.
+ HloComputation* scalar_add_computation_ = nullptr;
};
+} // namespace
+
bool AlgebraicSimplifierVisitor::Run(
HloComputation* computation, bool is_layout_sensitive,
AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
@@ -501,13 +503,13 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
}
static HloInstruction* BuildTupleConstant(HloComputation* computation,
- const Literal& literal) {
+ const LiteralSlice& literal) {
if (ShapeUtil::IsTuple(literal.shape())) {
std::vector<HloInstruction*> elems;
elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
elems.push_back(
- BuildTupleConstant(computation, LiteralView::Create(literal, {i})));
+ BuildTupleConstant(computation, LiteralSlice(literal, {i})));
}
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else {
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
index cf1231bcce..95b4cb6d2e 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.cc
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -101,7 +101,7 @@ StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
return result;
}
-tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
+Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
tensorflow::mutex_lock lock(mutex_);
VLOG(2) << "Unregister("
<< "handle: " << data.handle() << ")";
@@ -130,7 +130,7 @@ tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
for (auto& shaped_buffer : it->second) {
shaped_buffer.reset();
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
@@ -220,8 +220,10 @@ void AllocationTracker::AddAllocationOrIncrementRefCount(
AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
auto it = allocation_map.find(device_memory.opaque());
if (it == allocation_map.end()) {
- allocation_map[device_memory.opaque()] = {device_memory, device_ordinal,
- /*ref_count=*/1};
+ allocation_map[device_memory.opaque()] = {
+ OwningDeviceMemory(device_memory, device_ordinal,
+ backend_->memory_allocator()),
+ /*ref_count=*/1};
} else {
it->second.ref_count++;
}
@@ -235,13 +237,12 @@ Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory,
Allocation& allocation = it->second;
TF_RET_CHECK(allocation.ref_count >= 1);
if (allocation.ref_count == 1) {
- TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate(
- device_ordinal, &device_memory));
+ allocation.device_memory.Free();
allocation_map.erase(it);
} else {
allocation.ref_count--;
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
} // namespace xla
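
The Unregister and DecrementRefCount hunks above drop the manual
memory_allocator()->Deallocate call in favor of OwningDeviceMemory, the owning
handle added by this change (see owning_device_memory.cc/.h in the BUILD hunk).
As a rough standalone sketch of that RAII shape (hypothetical names, not the
real owning_device_memory.h API):

// Hypothetical allocator standing in for DeviceMemoryAllocator.
struct Allocator {
  void Deallocate(void* ptr) { /* return ptr to the device pool */ }
};

// Sketch of an owning handle: the destructor (or an explicit Free()) releases
// the allocation, and Forget() hands ownership back to the caller, as the
// CpuExecutable result-buffer code later in this diff does.
class OwningMemory {
 public:
  OwningMemory(void* ptr, Allocator* allocator)
      : ptr_(ptr), allocator_(allocator) {}
  OwningMemory(OwningMemory&& other) noexcept
      : ptr_(other.ptr_), allocator_(other.allocator_) {
    other.ptr_ = nullptr;  // the moved-from handle no longer owns anything
  }
  OwningMemory(const OwningMemory&) = delete;
  OwningMemory& operator=(const OwningMemory&) = delete;
  ~OwningMemory() { Free(); }

  void Free() {
    if (ptr_ != nullptr) {
      allocator_->Deallocate(ptr_);
      ptr_ = nullptr;
    }
  }

  void* Forget() {
    void* released = ptr_;
    ptr_ = nullptr;  // caller now owns the memory; the destructor is a no-op
    return released;
  }

 private:
  void* ptr_;
  Allocator* allocator_;
};
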
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h
index 1174fa641c..a7d8927cf7 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.h
+++ b/tensorflow/compiler/xla/service/allocation_tracker.h
@@ -76,10 +76,7 @@ class AllocationTracker {
// Data structure encapsulating single memory allocation on the device.
struct Allocation {
// The pointer to this allocation.
- se::DeviceMemoryBase device_memory;
-
- // The device that the memory is allocated on.
- int device_ordinal;
+ OwningDeviceMemory device_memory;
// This is the number of times this memory allocation is referred to by
// registered data handles.
@@ -126,7 +123,10 @@ class AllocationTracker {
int64 next_handle_ GUARDED_BY(mutex_);
// A map from device ordinal to AllocationMap.
- tensorflow::gtl::FlatMap<int, AllocationMap> opaque_to_allocation_map_
+ //
+ // This is not a TF FlatMap because (currently) FlatMap (and therefore
+ // AllocationMap) is not movable.
+ std::unordered_map<int, AllocationMap> opaque_to_allocation_map_
GUARDED_BY(mutex_);
// A map from data handle to a vector of shaped buffers that represent the
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
new file mode 100644
index 0000000000..2099916509
--- /dev/null
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+
+namespace xla {
+StatusOr<bool>
+BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
+ HloInstruction* batch_dot) {
+ const DotDimensionNumbers& dim_numbers = batch_dot->dot_dimension_numbers();
+ HloInstruction *lhs = batch_dot->mutable_operand(0),
+ *rhs = batch_dot->mutable_operand(1);
+ const Shape& lhs_shape = lhs->shape();
+
+ std::vector<int64> degenerate_dims;
+ for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) {
+ if (lhs_shape.dimensions(batch_dim) == 1) {
+ degenerate_dims.push_back(batch_dim);
+ }
+ }
+
+ if (degenerate_dims.empty()) {
+ return false;
+ }
+
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs,
+ ElideDegenerateDims(lhs, degenerate_dims));
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs,
+ ElideDegenerateDims(rhs, degenerate_dims));
+
+ DotDimensionNumbers new_dim_numbers = dim_numbers;
+ new_dim_numbers.clear_lhs_batch_dimensions();
+ new_dim_numbers.clear_rhs_batch_dimensions();
+
+ for (int64 i = 0, e = dim_numbers.lhs_batch_dimensions_size() -
+ degenerate_dims.size();
+ i < e; i++) {
+ new_dim_numbers.add_lhs_batch_dimensions(i);
+ new_dim_numbers.add_rhs_batch_dimensions(i);
+ }
+
+ new_dim_numbers.set_lhs_contracting_dimensions(
+ 0,
+ new_dim_numbers.lhs_contracting_dimensions(0) - degenerate_dims.size());
+ new_dim_numbers.set_rhs_contracting_dimensions(
+ 0,
+ new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size());
+
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
+ MakeDotHlo(new_lhs, new_rhs, new_dim_numbers));
+
+ TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
+ MakeReshapeHlo(batch_dot->shape(), new_dot));
+
+ VLOG(2) << "Replaced " << batch_dot->ToString() << " with "
+ << new_dot->ToString();
+
+ TF_RETURN_IF_ERROR(
+ batch_dot->parent()->ReplaceInstruction(batch_dot, new_dot_reshaped));
+
+ return true;
+}
+
+tensorflow::StringPiece BatchDotSimplification::name() const {
+ return "batch-dot-simplification";
+}
+
+StatusOr<bool> BatchDotSimplification::Run(HloModule* module) {
+ bool changed = false;
+ std::vector<HloInstruction*> dot_instrs;
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ c_copy_if(computation->instructions(), std::back_inserter(dot_instrs),
+ [](HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kDot;
+ });
+ }
+ for (HloInstruction* dot_instr : dot_instrs) {
+ TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one,
+ ElideDegenerateBatchDimensionFromBatchDot(dot_instr));
+ changed |= elided_batch_dim_from_one;
+ }
+ return changed;
+}
+} // namespace xla
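
For a concrete instance of the dimension-number arithmetic above (it mirrors
the ElideMultipleDegenerateBatchDotDims_VectorVector test added below):

  lhs = rhs = f32[9,1,7,1,3]
  batch_dims = {0,1,2,3}, contracting_dims = {4}
  degenerate (size-1) batch dims = {1,3}

  after ElideDegenerateDims: lhs = rhs = f32[9,7,3]
  new batch_dims = {0,1}                 (4 batch dims - 2 degenerate)
  new contracting_dims = {4 - 2} = {2}

  the new dot produces f32[9,7], which is then reshaped back to the original
  output shape f32[9,1,7,1].
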
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h
new file mode 100644
index 0000000000..c0ca8d8eba
--- /dev/null
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+// Simplifies batch dot operations.
+//
+// Normally these simplifications would live in the algebraic simplifier, but
+// we want to run this pass to a fixed point (it converges in a single
+// execution) before we run the DotDecomposer.
+class BatchDotSimplification : public HloPassInterface {
+ public:
+ StatusOr<bool> Run(HloModule* module) override;
+ tensorflow::StringPiece name() const override;
+
+ private:
+ StatusOr<bool> ElideDegenerateBatchDimensionFromBatchDot(
+ HloInstruction* batch_dot);
+};
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
new file mode 100644
index 0000000000..38f1a5d3a6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
@@ -0,0 +1,168 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+
+namespace xla {
+namespace {
+
+namespace op = xla::testing::opcode_matchers;
+
+class BatchDotSimplificationTest : public HloVerifiedTestBase {};
+
+TEST_F(BatchDotSimplificationTest,
+ ElideSingleDegenerateBatchDotDim_VectorVector) {
+ const string hlo_text = R"(
+HloModule BatchDot
+
+main {
+ a = f32[1,3] parameter(0)
+ b = f32[1,3] parameter(1)
+ ROOT dot = f32[1] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+)";
+
+ ParseAndVerifyModule(hlo_text);
+ BatchDotSimplification pass;
+ ASSERT_TRUE(pass.Run(&module()).ValueOrDie());
+
+ HloInstruction* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root,
+ op::Reshape(op::Dot(
+ op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
+ /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0)));
+}
+
+TEST_F(BatchDotSimplificationTest,
+ ElideSingleDegenerateBatchDotDim_MatrixVector) {
+ const string hlo_text = R"(
+HloModule BatchDot
+
+main {
+ a = f32[1,9,3] parameter(0)
+ b = f32[1,3] parameter(1)
+ ROOT dot = f32[1,9] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+}
+)";
+
+ ParseAndVerifyModule(hlo_text);
+ BatchDotSimplification pass;
+ ASSERT_TRUE(pass.Run(&module()).ValueOrDie());
+
+ HloInstruction* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root,
+ op::Reshape(op::Dot(
+ op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
+ /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)));
+}
+
+TEST_F(BatchDotSimplificationTest,
+ ElideSingleDegenerateBatchDotDim_MatrixMatrix) {
+ const string hlo_text = R"(
+HloModule BatchDot
+
+main {
+ a = f32[1,9,3] parameter(0)
+ b = f32[1,3,7] parameter(1)
+ ROOT dot = f32[1,9,7] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
+}
+)";
+
+ ParseAndVerifyModule(hlo_text);
+ BatchDotSimplification pass;
+ ASSERT_TRUE(pass.Run(&module()).ValueOrDie());
+
+ HloInstruction* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root,
+ op::Reshape(op::Dot(
+ op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
+ /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)));
+}
+
+TEST_F(BatchDotSimplificationTest,
+ ElideMultipleDegenerateBatchDotDims_VectorVector) {
+ const string hlo_text = R"(
+HloModule BatchDot
+
+main {
+ a = f32[9,1,7,1,3] parameter(0)
+ b = f32[9,1,7,1,3] parameter(1)
+ ROOT dot = f32[9,1,7,1] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={4}
+}
+)";
+
+ ParseAndVerifyModule(hlo_text);
+ BatchDotSimplification pass;
+ ASSERT_TRUE(pass.Run(&module()).ValueOrDie());
+
+ HloInstruction* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root,
+ op::Reshape(op::Dot(
+ op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
+ /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/2)));
+}
+
+TEST_F(BatchDotSimplificationTest,
+ ElideMultipleDegenerateBatchDotDims_VectorMatrix) {
+ const string hlo_text = R"(
+HloModule BatchDot
+
+main {
+ a = f32[9,1,7,1,3] parameter(0)
+ b = f32[9,1,7,1,20,3] parameter(1)
+ ROOT dot = f32[9,1,7,1,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={5}
+}
+)";
+
+ ParseAndVerifyModule(hlo_text);
+ BatchDotSimplification pass;
+ ASSERT_TRUE(pass.Run(&module()).ValueOrDie());
+
+ HloInstruction* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root,
+ op::Reshape(op::Dot(
+ op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
+ /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/3)));
+}
+
+TEST_F(BatchDotSimplificationTest,
+ ElideMultipleDegenerateBatchDotDims_MatrixMatrix) {
+ const string hlo_text = R"(
+HloModule BatchDot
+
+main {
+ a = f32[9,1,7,1,19,3] parameter(0)
+ b = f32[9,1,7,1,3,20] parameter(1)
+ ROOT dot = f32[9,1,7,1,19,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={5}, rhs_contracting_dims={4}
+}
+)";
+
+ ParseAndVerifyModule(hlo_text);
+ BatchDotSimplification pass;
+ ASSERT_TRUE(pass.Run(&module()).ValueOrDie());
+
+ HloInstruction* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root,
+ op::Reshape(op::Dot(
+ op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)),
+ /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2)));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index 38086bd7e1..96e02b82b9 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -15,35 +15,32 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
-#include <algorithm>
#include <memory>
-#include <numeric>
-#include <set>
#include <string>
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
+namespace {
+
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
@@ -80,17 +77,25 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
rewrite_grad_op_(rewrite_grad_op),
use_fusion_(use_fusion) {}
- HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type,
- HloOpcode opcode) {
- HloComputation::Builder b("scalar_computation");
- auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs"));
- auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs"));
- auto scalar_op = b.AddInstruction(
- HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
- opcode, scalar_lhs, scalar_rhs));
- return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
+ HloComputation* GetOrCreateScalarAddComputation(
+ PrimitiveType primitive_type) {
+ HloComputation** scalar_add_computation =
+ &scalar_add_computations_[primitive_type];
+ if (*scalar_add_computation) {
+ return *scalar_add_computation;
+ }
+
+ HloComputation::Builder b("scalar_add_computation");
+ Shape shape = ShapeUtil::MakeShape(primitive_type, {});
+ auto scalar_lhs = b.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
+ auto scalar_rhs = b.AddInstruction(
+ HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
+ auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
+ *scalar_add_computation =
+ computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
+ return *scalar_add_computation;
}
// Current HloComputation instance the BatchNormExpander is
@@ -105,6 +110,10 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
// Whether rewrite has occurred.
bool changed_ = false;
+ // Cached computations for adding two scalars.
+ tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*>
+ scalar_add_computations_;
+
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
@@ -129,6 +138,8 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
}
};
+} // namespace
+
bool BatchNormExpanderVisitor::Run(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
@@ -199,7 +210,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
HloComputation* add_reduce_computation =
- GetScalarBinaryComputation(ptype, HloOpcode::kAdd);
+ GetOrCreateScalarAddComputation(ptype);
// X^2.
auto operand_squared = add(HloInstruction::CreateBinary(
@@ -500,7 +511,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
grad_output, activation_minus_mean));
HloComputation* add_reduce_computation =
- GetScalarBinaryComputation(ptype, HloOpcode::kAdd);
+ GetOrCreateScalarAddComputation(ptype);
// sum(Grad[Y] * (X - E[X])).
auto sum_grad_output_times_activiation_minus_mean =
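
GetOrCreateScalarAddComputation above caches one scalar add computation per
primitive type and keeps a pointer to the FlatMap slot so the map is only
probed once. A small sketch of the same memoization idiom using a plain
std::unordered_map and a hypothetical Computation type:

#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for HloComputation; in the real pass the HloModule owns these.
struct Computation {
  std::string name;
};

class Expander {
 public:
  // Builds the scalar add computation for `primitive_type` at most once.
  Computation* GetOrCreateScalarAdd(int primitive_type) {
    // operator[] default-inserts an empty entry on first use; holding a
    // pointer to that slot lets us test and then fill it without a second
    // lookup, mirroring the HloComputation** slot in the pass above.
    std::unique_ptr<Computation>* slot =
        &scalar_add_computations_[primitive_type];
    if (*slot == nullptr) {
      *slot = std::make_unique<Computation>();
      (*slot)->name = "scalar_add_computation";
    }
    return slot->get();
  }

 private:
  std::unordered_map<int, std::unique_ptr<Computation>> scalar_add_computations_;
};
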
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 313910a861..5e1499ee6b 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -149,12 +149,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
EXPECT_TRUE(OutputsBF16(dot->operand(1)));
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
dot->operand(0)->literal(),
- *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)));
- LiteralTestUtil::ExpectEqual(
+ *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
dot->operand(1)->literal(),
- *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)));
+ *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))));
}
// Tests that BF16 can be propagated through nested tuples.
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 94ccfedf62..c0b8bf9039 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -699,7 +700,7 @@ BufferAssignmentProto BufferAssignment::ToProto() const {
BufferAssignmentProto::BufferAlias* proto_alias =
proto.add_buffer_aliases();
LogicalBufferProto::Location proto_alias_location =
- LogicalBuffer::ToLocationProto(*alias.instruction(), alias.index());
+ BufferValue::ToLocationProto(*alias.instruction(), alias.index());
proto_alias->set_source_buffer_id(buffer.id());
proto_alias->mutable_location()->Swap(&proto_alias_location);
}
@@ -1083,7 +1084,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
VLOG(2) << "Simulating heap for color " << color;
int64 alignment = assignment->color_alignment_(color);
HeapSimulator::Options options;
- options.buffers_to_assign = &single_colored_set.second;
+ BufferValueFlatSet buffer_value_set =
+ ToBufferValueFlatSet(single_colored_set.second);
+ options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
@@ -1111,7 +1114,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
VLOG(2) << "Simulating heap for color " << color;
int64 alignment = assignment->color_alignment_(color);
HeapSimulator::Options options;
- options.buffers_to_assign = &single_colored_set.second;
+ BufferValueFlatSet buffer_value_set =
+ ToBufferValueFlatSet(single_colored_set.second);
+ options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
@@ -1224,7 +1229,10 @@ void BufferAssigner::AssignBuffersFromHeapSimulator(
BufferAllocation* allocation = assignment->NewEmptyAllocation(
result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true, color);
for (const auto& buffer_chunk : result.chunk_map) {
- const LogicalBuffer& buffer = *buffer_chunk.first;
+ // TODO(lauj) Remove this down_cast after downstream users of
+ // BufferAllocation::assigned_buffers() are updated to use BufferValue.
+ const LogicalBuffer& buffer =
+ *CHECK_NOTNULL(dynamic_cast<const LogicalBuffer*>(buffer_chunk.first));
const HeapSimulator::Chunk& chunk = buffer_chunk.second;
assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size);
}
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index 37982aaef9..acb546a0a1 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -44,7 +44,7 @@ StatusOr<std::unique_ptr<BufferLiveness>> BufferLiveness::Run(
return std::move(liveness);
}
-tensorflow::Status BufferLiveness::Analyze() {
+Status BufferLiveness::Analyze() {
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_));
for (auto* computation : module_->computations()) {
if (computation->IsFusionComputation()) {
@@ -71,7 +71,7 @@ tensorflow::Status BufferLiveness::Analyze() {
}
XLA_VLOG_LINES(3, ToString());
- return tensorflow::Status::OK();
+ return Status::OK();
}
string BufferLiveness::ToString() const {
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h
index 11834a5127..cdd3cf4032 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.h
+++ b/tensorflow/compiler/xla/service/buffer_liveness.h
@@ -89,7 +89,7 @@ class BufferLiveness {
// Perform buffer liveness analysis. This method must be called prior to
// MayInterfere or MaybeLiveOut.
- tensorflow::Status Analyze();
+ Status Analyze();
// Returns true if the live range of the buffer of 'a' is strictly before the
// live range of the buffer of 'b' (they do not overlap).
diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h
new file mode 100644
index 0000000000..305914fca8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/buffer_value_containers.h
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_
+
+#include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/core/lib/gtl/compactptrset.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace xla {
+
+// Define various containers of BufferValues, and utilities to convert from
+// containers of LogicalBuffers to containers of BufferValues.
+
+using BufferValueCompactPointerSet =
+ tensorflow::gtl::CompactPointerSet<const BufferValue*>;
+template <class LogicalBufferContainerT>
+BufferValueCompactPointerSet ToBufferValueCompactPointerSet(
+ const LogicalBufferContainerT& logical_buffer_container) {
+ BufferValueCompactPointerSet output;
+ for (const LogicalBuffer* buffer : logical_buffer_container) {
+ output.insert(buffer);
+ }
+ return output;
+}
+
+using BufferValueFlatSet = tensorflow::gtl::FlatSet<const BufferValue*>;
+template <class LogicalBufferContainerT>
+BufferValueFlatSet ToBufferValueFlatSet(
+ const LogicalBufferContainerT& logical_buffer_container) {
+ BufferValueFlatSet output;
+ output.reserve(logical_buffer_container.size());
+ for (const LogicalBuffer* buffer : logical_buffer_container) {
+ output.insert(buffer);
+ }
+ return output;
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_
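
A brief usage sketch for these helpers (the function below is hypothetical; the
real call sites are the buffer_assignment.cc hunks above, which convert the
heap simulator's colored buffer sets):

#include <vector>

#include "tensorflow/compiler/xla/service/buffer_value_containers.h"

namespace xla {

BufferValueFlatSet AsBufferValues(
    const std::vector<const LogicalBuffer*>& buffers) {
  // Each const LogicalBuffer* converts to const BufferValue*, so the helper
  // only has to copy the pointers into the FlatSet.
  return ToBufferValueFlatSet(buffers);
}

}  // namespace xla
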
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
index c10609e67f..7f2ce0e897 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.h
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -75,48 +75,42 @@ class CompileOnlyService : public Service {
// Override Service methods that require or imply the existence of an
// execute backend. Note that this does not include TransferToClient, as
// computing constants produces global data that we may wish to transfer.
- tensorflow::Status Execute(const ExecuteRequest* arg,
- ExecuteResponse* result) override {
+ Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
}
- tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
- ExecuteParallelResponse* result) override {
+ Status ExecuteParallel(const ExecuteParallelRequest* arg,
+ ExecuteParallelResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
}
- tensorflow::Status GetDeviceHandles(
- const GetDeviceHandlesRequest* arg,
- GetDeviceHandlesResponse* result) override {
+ Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
+ GetDeviceHandlesResponse* result) override {
return Unimplemented("CompileOnlyService does not support devices.");
}
- tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) override {
+ Status ExecuteAsync(const ExecuteAsyncRequest* arg,
+ ExecuteAsyncResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
}
- tensorflow::Status WaitForExecution(
- const WaitForExecutionRequest* arg,
- WaitForExecutionResponse* result) override {
+ Status WaitForExecution(const WaitForExecutionRequest* arg,
+ WaitForExecutionResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
}
- tensorflow::Status TransferToServer(
- const TransferToServerRequest* arg,
- TransferToServerResponse* result) override {
+ Status TransferToServer(const TransferToServerRequest* arg,
+ TransferToServerResponse* result) override {
return Unimplemented(
"CompileOnlyService does not support device data transfers.");
}
- tensorflow::Status TransferToInfeed(
- const TransferToInfeedRequest* arg,
- TransferToInfeedResponse* result) override {
+ Status TransferToInfeed(const TransferToInfeedRequest* arg,
+ TransferToInfeedResponse* result) override {
return Unimplemented(
"CompileOnlyService does not support device data transfers.");
}
- tensorflow::Status TransferFromOutfeed(
- const TransferFromOutfeedRequest* arg,
- TransferFromOutfeedResponse* result) override {
+ Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
+ TransferFromOutfeedResponse* result) override {
return Unimplemented(
"CompileOnlyService does not support device data transfers.");
}
- tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
- ResetDeviceResponse* result) override {
+ Status ResetDevice(const ResetDeviceRequest* arg,
+ ResetDeviceResponse* result) override {
return Unimplemented("CompileOnlyService does not support devices.");
}
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 7e6d58c7fa..5f5b81686a 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -103,6 +103,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:algebraic_simplifier",
+ "//tensorflow/compiler/xla/service:batch_dot_simplification",
"//tensorflow/compiler/xla/service:batchnorm_expander",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:buffer_liveness",
@@ -296,6 +297,15 @@ cc_library(
)
cc_library(
+ name = "target_machine_features_fake",
+ testonly = 1,
+ hdrs = ["target_machine_features_fake.h"],
+ deps = [
+ ":target_machine_features",
+ ],
+)
+
+cc_library(
name = "ir_function",
srcs = ["ir_function.cc"],
hdrs = ["ir_function.h"],
@@ -336,6 +346,7 @@ cc_library(
deps = [
":cpu_options",
":cpu_runtime",
+ ":ir_emission_utils",
":target_machine_features",
":vector_support_library",
"//tensorflow/compiler/xla:shape_util",
@@ -660,6 +671,7 @@ cc_library(
hdrs = ["ir_emission_utils.h"],
deps = [
":cpu_runtime",
+ ":target_machine_features",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/service:hlo",
@@ -672,6 +684,7 @@ tf_cc_test(
srcs = ["ir_emission_utils_test.cc"],
deps = [
":ir_emission_utils",
+ ":target_machine_features_fake",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
@@ -690,6 +703,7 @@ cc_library(
deps = [
":dot_op_emitter",
":ir_emission_utils",
+ ":target_machine_features",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:layout_assignment",
@@ -703,6 +717,7 @@ tf_cc_test(
srcs = ["cpu_layout_assignment_test.cc"],
deps = [
":cpu_layout_assignment",
+ ":target_machine_features_fake",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
@@ -727,6 +742,7 @@ cc_library(
deps = [
":cpu_runtime",
":ir_emission_utils",
+ ":target_machine_features",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -741,6 +757,7 @@ tf_cc_test(
srcs = ["conv_canonicalization_test.cc"],
deps = [
":conv_canonicalization",
+ ":target_machine_features_fake",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
@@ -779,6 +796,7 @@ cc_library(
":dot_op_emitter",
":ir_emission_utils",
":shape_partition",
+ ":target_machine_features",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
@@ -791,6 +809,7 @@ tf_cc_test(
deps = [
":cpu_executable",
":parallel_task_assignment",
+ ":target_machine_features_fake",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
@@ -913,3 +932,17 @@ tf_cc_test(
"//tensorflow/core:test",
],
)
+
+tf_cc_test(
+ name = "cpu_eigen_tensor_alignment_test",
+ size = "small",
+ srcs = ["cpu_eigen_tensor_alignment_test.cc"],
+ deps = [
+ ":dot_op_emitter",
+ ":ir_emission_utils",
+ ":target_machine_features_fake",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
index 2136aeb387..0985b9297f 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -33,7 +33,8 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
for (HloInstruction* hlo :
module->entry_computation()->MakeInstructionPostOrder()) {
if (hlo->opcode() == HloOpcode::kConvolution &&
- !PotentiallyImplementedAsEigenConvolution(*hlo)) {
+ !PotentiallyImplementedAsEigenConvolution(*hlo,
+ target_machine_features_)) {
const ConvolutionDimensionNumbers& dnums =
hlo->convolution_dimension_numbers();
auto input_batch_dim = dnums.input_batch_dimension();
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
index 9b2c3d82eb..e6fd1499ed 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -32,12 +33,19 @@ namespace cpu {
// convolutions can run faster.
class ConvCanonicalization : public HloPassInterface {
public:
+ explicit ConvCanonicalization(
+ const TargetMachineFeatures* target_machine_features)
+ : target_machine_features_(*target_machine_features) {}
+
~ConvCanonicalization() override {}
tensorflow::StringPiece name() const override {
return "convolution-canonicalization";
}
StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ const TargetMachineFeatures& target_machine_features_;
};
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 968f53d5c7..375b017b09 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -89,7 +90,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build());
- ConvCanonicalization conv_canonicalization;
+ cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features(
+ [](int64 shape_size) {
+ return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
+ });
+ ConvCanonicalization conv_canonicalization(&target_machine_features);
EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
const HloInstruction* output_reshape = entry_computation->root_instruction();
@@ -146,7 +151,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- ConvCanonicalization conv_canonicalization;
+ cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features(
+ [](int64 shape_size) {
+ return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
+ });
+ ConvCanonicalization conv_canonicalization(&target_machine_features);
EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie());
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 3d2e24ca14..beeb826747 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
@@ -231,7 +232,10 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
};
} // namespace
-Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
+Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
+ llvm::TargetMachine* target_machine) {
+ LLVMTargetMachineFeatures target_machine_features(target_machine);
+
// Optimization pipeline.
HloPassPipeline pipeline("CPU");
pipeline.AddInvariantChecker<HloVerifier>();
@@ -248,8 +252,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
// TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
// pass.
pipeline.AddPass<CallInliner>();
+ pipeline.AddPass<BatchDotSimplification>();
pipeline.AddPass<DotDecomposer>();
- pipeline.AddPass<ConvCanonicalization>();
+ pipeline.AddPass<ConvCanonicalization>(&target_machine_features);
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
@@ -279,9 +284,10 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
pass.AddPass<ConditionalSimplifier>();
}
pipeline.AddPass<TransposeFolding>(
- [](const HloInstruction& dot,
- const TransposeFolding::OperandIndices& candidate_operands) {
- return PotentiallyImplementedAsEigenDot(dot)
+ [&target_machine_features](
+ const HloInstruction& dot,
+ const TransposeFolding::OperandIndices& candidate_operands) {
+ return PotentiallyImplementedAsEigenDot(dot, target_machine_features)
? candidate_operands
: TransposeFolding::OperandIndices{};
},
@@ -296,7 +302,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->device_entry_computation_layout());
+ module->device_entry_computation_layout(), &target_machine_features);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
@@ -316,8 +322,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
// and thread synchronization dependencies which would likely increase
// binary size (and most AOT applications are single-threaded).
// TODO(b/29630486) Support multi-threaded AOT.
- pipeline.AddPass<ParallelTaskAssigner>(max_parallelism,
- ShapeSizeBytesFunction());
+ pipeline.AddPass<ParallelTaskAssigner>(
+ max_parallelism, ShapeSizeBytesFunction(), &target_machine_features);
}
// Copy insertion should be performed immediately before IR emission to avoid
// inserting unnecessary copies (later pass adds an instruction which
@@ -470,7 +476,13 @@ StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
VLOG(2) << "Before optimization:";
XLA_VLOG_LINES(2, module->ToString());
- TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false));
+ std::unique_ptr<llvm::TargetMachine> jit_target_machine =
+ SimpleOrcJIT::InferTargetMachineForJIT(
+ CompilerTargetOptions(module->config()),
+ CodeGenOptLevel(module->config()));
+
+ TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false,
+ jit_target_machine.get()));
VLOG(2) << "After optimization:";
XLA_VLOG_LINES(2, module->ToString());
@@ -561,10 +573,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// GetEmbeddedComputations guarantees that a called computation occurs
// before a caller computation.
+ LLVMTargetMachineFeatures target_machine_features(jit->target_machine());
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
std::move(instruction_to_profile_idx),
std::move(computation_to_profile_idx),
- jit->target_machine(), jit->external_constant_pool());
+ &target_machine_features, jit->external_constant_pool());
for (auto embedded_computation :
entry_computation->MakeEmbeddedComputationsList()) {
@@ -706,7 +719,8 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
VLOG(2) << "Before optimization:";
XLA_VLOG_LINES(2, module->ToString());
- TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true));
+ TF_RETURN_IF_ERROR(
+ RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get()));
VLOG(2) << "After optimization:";
XLA_VLOG_LINES(2, module->ToString());
@@ -746,10 +760,11 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
&hlo_profile_index_map, &hlo_profile_printer_data));
}
+ LLVMTargetMachineFeatures target_machine_features(target_machine.get());
IrEmitter ir_emitter(*module, *assignment, &llvm_module,
std::move(instruction_to_profile_idx),
std::move(computation_to_profile_idx),
- target_machine.get(),
+ &target_machine_features,
/*external_constant_pool=*/nullptr);
HloComputation* computation = module->entry_computation();
for (auto embedded_computation :
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index 65b05f04fa..e56f9f0113 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
@@ -148,7 +149,8 @@ class CpuCompiler : public LLVMCompiler {
// Runs the HLO passes which are necessary for both optimizations and
// correctness.
- Status RunHloPasses(HloModule* module, bool is_aot_compile);
+ Status RunHloPasses(HloModule* module, bool is_aot_compile,
+ llvm::TargetMachine* target_machine);
TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler);
};
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc
new file mode 100644
index 0000000000..d12fa6bb9a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+
+// Test that we don't call into Eigen with tensors too small to be aligned
+// reliably.
+
+class CpuEigenTensorAlignmentTest : public ::testing::Test {};
+
+TEST_F(CpuEigenTensorAlignmentTest, EigenDotAlignment) {
+ string hlo_string = R"(
+HloModule DotOperation
+
+ENTRY DotOperation {
+ arg0 = f32[5,256] parameter(0)
+ arg1 = f32[256,1024] parameter(1)
+ ROOT dot = f32[5,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+
+ HloInstruction* dot = module->entry_computation()->root_instruction();
+
+ TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment(
+ [](int64 size) { return 1; });
+
+ EXPECT_FALSE(
+ PotentiallyImplementedAsEigenDot(*dot, target_machine_with_no_alignment));
+
+ TargetMachineFeaturesWithFakeAlignmentLogic
+ target_machine_with_full_alignment([](int64 size) {
+ return TargetMachineFeatures::kEigenExpectedTensorAlignment;
+ });
+
+ EXPECT_TRUE(PotentiallyImplementedAsEigenDot(
+ *dot, target_machine_with_full_alignment));
+}
+
+TEST_F(CpuEigenTensorAlignmentTest, EigenConvAlignment) {
+ string hlo_string = R"(
+HloModule ConvOperation
+
+ENTRY ConvOperation {
+ arg0 = f32[1,2,1] parameter(0)
+ arg1 = f32[1,1,1] parameter(1)
+ ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1}, dim_labels=b0f_0io->b0f
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+
+ HloInstruction* conv = module->entry_computation()->root_instruction();
+
+ TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment(
+ [](int64 size) { return 1; });
+
+ EXPECT_FALSE(PotentiallyImplementedAsEigenConvolution(
+ *conv, target_machine_with_no_alignment));
+
+ TargetMachineFeaturesWithFakeAlignmentLogic
+ target_machine_with_full_alignment([](int64 size) {
+ return TargetMachineFeatures::kEigenExpectedTensorAlignment;
+ });
+
+ EXPECT_TRUE(PotentiallyImplementedAsEigenConvolution(
+ *conv, target_machine_with_full_alignment));
+}
+} // namespace
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 32613b8690..cf43b74c69 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -73,7 +73,7 @@ CpuExecutable::CpuExecutable(
Status CpuExecutable::AllocateBuffers(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- std::vector<se::DeviceMemoryBase>* buffers) {
+ std::vector<OwningDeviceMemory>* buffers) {
CHECK_EQ(buffers->size(), assignment_->Allocations().size());
VLOG(3) << "Allocating " << assignment_->Allocations().size()
<< " allocations for module " << module().name();
@@ -201,60 +201,18 @@ Status CpuExecutable::ExecuteComputeFunction(
return Status::OK();
}
-static void LogLiveAddresses(
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
- const std::vector<bool>& buffers_in_result) {
- if (!VLOG_IS_ON(3)) {
- return;
- }
-
- CHECK_EQ(buffers.size(), buffers_in_result.size());
- std::vector<const void*> live_out_buffers;
- for (int i = 0; i < buffers.size(); ++i) {
- if (buffers_in_result[i]) {
- live_out_buffers.push_back(buffers[i].opaque());
- }
- }
- VLOG(3) << "Live addresses in output marking found "
- << live_out_buffers.size() << " addresses:\n"
- << tensorflow::str_util::Join(
- live_out_buffers, ", ", [](string* out, const void* address) {
- tensorflow::strings::StrAppend(
- out, tensorflow::strings::Printf("%p", address));
- });
-}
-
-static Status DeallocateTempBuffers(
- DeviceMemoryAllocator* allocator, se::Stream* stream,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
- const std::vector<bool>& buffers_in_result) {
- // Keep those buffers in the output of the marked live because they are needed
- // by the service. They will be deallocated by the service.
- for (size_t i = 0; i < buffers.size(); ++i) {
- se::DeviceMemoryBase alloc = buffers[i];
- if (!buffers_in_result[i] && !alloc.is_null()) {
- VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
- << alloc.opaque() << "]";
- TF_RETURN_IF_ERROR(
- allocator->Deallocate(stream->parent()->device_ordinal(), &alloc));
- }
- }
-
- return Status::OK();
-}
-
StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> allocated_buffers,
- std::vector<bool>* buffers_in_result) {
+ tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers) {
se::Stream* stream = run_options->stream();
ScopedShapedBuffer result_buffer(
/*on_host_shape=*/host_result_shape(),
/*on_device_shape=*/host_result_shape(), run_options->allocator(),
stream->parent()->device_ordinal());
- // Copy DeviceMemoryBase values which contain the array(s) of the result into
- // the respective location in ShapedBuffer which is returned to the caller.
+ // Move OwningDeviceMemory values which contain the array(s) of the result
+ // into the respective location in ScopedShapedBuffer which is returned to the
+ // caller.
TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus(
[&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) {
const auto& sources = this->GetRootPointsToSet().element(index);
@@ -273,10 +231,9 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
CHECK(!slice.allocation()->is_entry_computation_parameter());
const BufferAllocation::Index buffer_index = slice.index();
- const se::DeviceMemoryBase& buffer = allocated_buffers[buffer_index];
+ OwningDeviceMemory& buffer = buffers[buffer_index];
CHECK(!buffer.is_null() || buffer.size() == 0);
- *device_memory = buffer;
- (*buffers_in_result)[buffer_index] = true;
+ *device_memory = buffer.Forget();
return Status::OK();
}));
return std::move(result_buffer);
@@ -292,23 +249,21 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
- std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
+ std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
TF_RETURN_IF_ERROR(AllocateBuffers(
memory_allocator, stream->parent()->device_ordinal(), &buffers));
- TF_RETURN_IF_ERROR(ExecuteComputeFunction(
- &run_options->run_options(), arguments, buffers, hlo_execution_profile));
- std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
- TF_ASSIGN_OR_RETURN(
- ScopedShapedBuffer result_buffer,
- CreateResultShapedBuffer(run_options, buffers, &buffers_in_result));
-
- // Free all buffers not in the result.
- TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers,
- buffers_in_result));
+ std::vector<se::DeviceMemoryBase> unowning_buffers;
+ unowning_buffers.reserve(buffers.size());
+ for (auto& buffer : buffers) {
+ unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
+ }
+ TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(),
+ arguments, unowning_buffers,
+ hlo_execution_profile));
- return std::move(result_buffer);
+ return CreateResultShapedBuffer(run_options, &buffers);
}
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
@@ -324,30 +279,53 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
run_options->stream()->implementation());
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
- std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
-
+ std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
TF_RETURN_IF_ERROR(AllocateBuffers(
memory_allocator, stream->parent()->device_ordinal(), &buffers));
- std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
- TF_ASSIGN_OR_RETURN(
- ScopedShapedBuffer result_buffer,
- CreateResultShapedBuffer(run_options, buffers, &buffers_in_result));
-
- LogLiveAddresses(buffers, buffers_in_result);
-
- host_stream->EnqueueTask([this, run_options, arguments, buffers,
- buffers_in_result, memory_allocator, stream]() {
- // Failing a CHECK here is not great, but I don't see an obvious way to
- // return a failed Status asynchronously.
- TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments,
- buffers,
- /*hlo_execution_profile=*/nullptr));
- TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers,
- buffers_in_result));
- });
+ std::vector<se::DeviceMemoryBase> unowning_buffers;
+ unowning_buffers.reserve(buffers.size());
+ for (auto& buffer : buffers) {
+ unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
+ }
+ TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
+ CreateResultShapedBuffer(run_options, &buffers));
- return std::move(result_buffer);
+ // At this point, `unowning_buffers` contains unowning pointers to all of our
+ // buffers, and `buffers` contains owning pointers to the non-live-out
+ // buffers. Enqueue a task which keeps alive the non-live-out buffers.
+ //
+ // Logically we want this lambda to capture `buffers` by move, but ultimately
+ // our functor needs to be wrapped in an std::function, and that requires the
+ // functor to be copyable. Thus we perpetrate the hack of capturing buffers
+ // "by shared pointer".
+ //
+ // We also need to change the types of some of the variables we capture:
+ // run_options needs to change from a pointer to a value type, and arguments
+ // needs to change from an ArraySlice into a vector. We use a struct instead
+ // of a lambda to make this explicit.
+ struct AsyncRunTask {
+ CpuExecutable* executable;
+ ServiceExecutableRunOptions run_options;
+ std::vector<const ShapedBuffer*> arguments;
+ std::vector<se::DeviceMemoryBase> unowning_buffers;
+ std::shared_ptr<std::vector<OwningDeviceMemory>> buffers;
+
+ void operator()() {
+ // Failing a CHECK here is not great, but I don't see an obvious way to
+ // return a failed Status asynchronously.
+ TF_CHECK_OK(executable->ExecuteComputeFunction(
+ &run_options.run_options(), arguments, unowning_buffers,
+ /*hlo_execution_profile=*/nullptr));
+ }
+ };
+ host_stream->EnqueueTask(AsyncRunTask{
+ this, *run_options,
+ std::vector<const ShapedBuffer*>(arguments.begin(), arguments.end()),
+ unowning_buffers,
+ std::make_shared<std::vector<OwningDeviceMemory>>(std::move(buffers))});
+
+ return std::move(result);
}
/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
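Editor's note: the AsyncRunTask comment above describes a common C++ workaround rather than anything XLA-specific: std::function requires a copyable callable, so move-only state (here the OwningDeviceMemory buffers) is parked behind a std::shared_ptr that the functor copies. The following is a minimal standalone sketch of that pattern, not TensorFlow code; MoveOnlyBuffer is a made-up stand-in for the move-only resource.

// Standalone sketch (not part of this patch): keeping move-only state alive
// inside a copyable std::function by capturing a shared_ptr.
#include <functional>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for a move-only resource such as OwningDeviceMemory.
struct MoveOnlyBuffer {
  explicit MoveOnlyBuffer(int v) : value(v) {}
  MoveOnlyBuffer(MoveOnlyBuffer&&) = default;
  MoveOnlyBuffer& operator=(MoveOnlyBuffer&&) = default;
  MoveOnlyBuffer(const MoveOnlyBuffer&) = delete;
  MoveOnlyBuffer& operator=(const MoveOnlyBuffer&) = delete;
  int value;
};

int main() {
  std::vector<MoveOnlyBuffer> buffers;
  buffers.emplace_back(42);

  // A lambda that captured `buffers` by move would itself be move-only and
  // could not be stored in a std::function. Copying a shared_ptr is cheap and
  // keeps the buffers alive until the last copy of the task is destroyed.
  auto shared =
      std::make_shared<std::vector<MoveOnlyBuffer>>(std::move(buffers));
  std::function<void()> task = [shared] {
    std::cout << "buffer value: " << (*shared)[0].value << "\n";
  };
  task();
  return 0;
}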
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 68ad38cba8..8dd47bfb86 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -92,7 +92,7 @@ class CpuExecutable : public Executable {
// buffer is assigned for this element.
Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator,
int device_ordinal,
- std::vector<se::DeviceMemoryBase>* buffers);
+ std::vector<OwningDeviceMemory>* buffers);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
@@ -102,16 +102,12 @@ class CpuExecutable : public Executable {
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile);
- // Creates a ScopedShapedBuffer for holding the result of the computation. The
- // addresses (DeviceMemoryBases) are set according to buffer assignment.
- // 'buffers_in_result' should point to a vector of the same size as
- // 'allocated_buffers'. An element in buffers_in_result is set to true if the
- // corresponding buffer is live out of the computation (and thus contained in
- // the returned ShapedBuffer).
+ // Creates a ScopedShapedBuffer for holding the result of the computation,
+ // moving buffers out of allocated_buffers and into the result as appropriate.
+ // The addresses are set according to buffer assignment.
StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> allocated_buffers,
- std::vector<bool>* buffers_in_result);
+ tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers);
// Returns the points-to set of the root instruction of the entry
// computation. Uses points-to analysis from buffer assignment.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
index 6c642080c3..aa872d5ec9 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
@@ -100,7 +100,8 @@ Status CpuLayoutAssignment::AddBackendConstraints(
const HloComputation* computation = constraints->computation();
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kConvolution &&
- PotentiallyImplementedAsEigenConvolution(*instruction)) {
+ PotentiallyImplementedAsEigenConvolution(*instruction,
+ target_machine_features_)) {
const HloInstruction* convolution = instruction;
const HloInstruction* lhs_instruction = convolution->operand(0);
const HloInstruction* rhs_instruction = convolution->operand(1);
@@ -126,7 +127,8 @@ Status CpuLayoutAssignment::AddBackendConstraints(
const HloInstruction* op = instruction->operand(*op_idx);
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
ColMajorShape(op->shape()), instruction, *op_idx));
- } else if (PotentiallyImplementedAsEigenDot(*instruction)) {
+ } else if (PotentiallyImplementedAsEigenDot(*instruction,
+ target_machine_features_)) {
const HloInstruction* dot = instruction;
// In order to implement `dot` with Eigen dot, the layouts of the lhs,
// rhs, and output need to be row-major.
@@ -177,7 +179,7 @@ Status CpuLayoutAssignment::AddBackendConstraints(
}
}
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
index 09adb5cb02..53536a277c 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_
#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/core/lib/core/status.h"
@@ -28,12 +29,16 @@ namespace cpu {
class CpuLayoutAssignment : public LayoutAssignment {
public:
explicit CpuLayoutAssignment(
- const ComputationLayout& entry_computation_layout)
- : LayoutAssignment(entry_computation_layout) {}
+ const ComputationLayout& entry_computation_layout,
+ const TargetMachineFeatures* target_machine_features)
+ : LayoutAssignment(entry_computation_layout),
+ target_machine_features_(*target_machine_features) {}
~CpuLayoutAssignment() override {}
protected:
Status AddBackendConstraints(LayoutConstraints* constraints) override;
+
+ const TargetMachineFeatures& target_machine_features_;
};
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index ba4c5a23d3..f6c93d36f7 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -49,7 +50,12 @@ class CpuLayoutAssignmentTest : public HloTestBase {
protected:
void AssignLayouts(HloModule* module,
ComputationLayout* entry_computation_layout) {
- cpu::CpuLayoutAssignment layout_assignment(*entry_computation_layout);
+ cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features(
+ [](int64 shape_size) {
+ return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
+ });
+ cpu::CpuLayoutAssignment layout_assignment(*entry_computation_layout,
+ &target_machine_features);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
};
@@ -311,7 +317,12 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
result.addend_fusion_param = fusion_instruction->operand(
fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number());
- cpu::CpuLayoutAssignment layout_assignment(computation_layout);
+ cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features(
+ [](int64 shape_size) {
+ return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
+ });
+ cpu::CpuLayoutAssignment layout_assignment(computation_layout,
+ &target_machine_features);
TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
layout_assignment.Run(module));
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 9b39e7f576..d97802ee45 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -88,8 +88,8 @@ CpuTransferManager::CpuTransferManager()
: GenericTransferManager(se::host::kHostPlatformId,
/*pointer_size=*/sizeof(void*)) {}
-Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) {
+Status CpuTransferManager::TransferLiteralToInfeed(
+ se::StreamExecutor* executor, const LiteralSlice& literal) {
const Shape& shape = literal.shape();
VLOG(2) << "Transferring literal to infeed with shape: "
<< ShapeUtil::HumanString(shape);
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
index 3ecb0d2364..6dfc666f09 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
@@ -38,7 +38,7 @@ class CpuTransferManager : public GenericTransferManager {
~CpuTransferManager() override {}
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) override;
+ const LiteralSlice& literal) override;
Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
const void* source) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 8db4a0650d..5cdfc110af 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -541,7 +542,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot,
hlo_module_config_(hlo_module_config),
target_machine_features_(target_machine_features) {}
-/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation(
+/* static */ Status DotOpEmitter::EmitDotOperation(
const HloInstruction& dot, const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
const llvm_ir::IrArray* addend_array,
@@ -690,7 +691,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
return true;
}
-tensorflow::Status DotOpEmitter::Emit() {
+Status DotOpEmitter::Emit() {
// The dot operation performs a sum of products over dimension 0 of the left
// hand side operand and dimension 1 of the right hand side operand.
//
@@ -734,7 +735,7 @@ tensorflow::Status DotOpEmitter::Emit() {
CHECK_EQ(addend_array_, nullptr);
- if (PotentiallyImplementedAsEigenDot(dot_)) {
+ if (PotentiallyImplementedAsEigenDot(dot_, target_machine_features_)) {
return EmitCallToRuntime();
}
@@ -868,10 +869,10 @@ tensorflow::Status DotOpEmitter::Emit() {
// loop.
ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status DotOpEmitter::EmitScalarDot() {
+Status DotOpEmitter::EmitScalarDot() {
// A scalar dot is just a scalar multiply.
llvm::Value* result;
llvm::Value* lhs_value =
@@ -896,10 +897,10 @@ tensorflow::Status DotOpEmitter::EmitScalarDot() {
result = ir_builder_->CreateFMul(lhs_value, rhs_value);
}
target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_);
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
+Status DotOpEmitter::EmitCallToRuntime() {
// The signature of the Eigen runtime matmul function is:
//
// (void)(void* run_options, float* out, float* lhs, float* rhs,
@@ -1001,7 +1002,7 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
ir_builder_->getInt64(mat_mult_dims.k),
ir_builder_->getInt32(transpose_lhs),
ir_builder_->getInt32(transpose_rhs)});
- return tensorflow::Status::OK();
+ return Status::OK();
}
DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
@@ -1058,19 +1059,39 @@ static bool IsRank2WithNoPadding(const Shape& shape) {
// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
-static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
- const Shape& output_shape) {
+static bool AreValidGemmShapes(
+ const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape,
+ const TargetMachineFeatures& target_machine_features) {
// The inputs and the output must
// 1) be matrices with no padding, and
// 2) have an allowed element type.
PrimitiveType output_primitive_type = output_shape.element_type();
- return (output_primitive_type == F64 || output_primitive_type == F32 ||
- output_primitive_type == F16) &&
- IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
- IsRank2WithNoPadding(output_shape);
+ if (!(output_primitive_type == F64 || output_primitive_type == F32 ||
+ output_primitive_type == F16)) {
+ return false;
+ }
+
+ if (!(IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
+ IsRank2WithNoPadding(output_shape))) {
+ return false;
+ }
+
+ auto is_aligned = [&](const Shape& shape) {
+ return GetMinimumAlignmentForArray(shape, target_machine_features) >=
+ TargetMachineFeatures::kEigenExpectedTensorAlignment;
+ };
+
+ if (!is_aligned(lhs_shape) || !is_aligned(rhs_shape) ||
+ !is_aligned(output_shape)) {
+ return false;
+ }
+
+ return true;
}
-bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
+bool PotentiallyImplementedAsEigenDot(
+ const HloInstruction& hlo,
+ const TargetMachineFeatures& target_machine_features) {
// For certain types of Dot, we can call Eigen
if (hlo.opcode() == HloOpcode::kDot) {
const Shape& lhs_shape = hlo.operand(0)->shape();
@@ -1087,7 +1108,8 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
// If gemm can accept the operand shapes, use it rather than a custom
// kernel.
- if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
+ if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape(),
+ target_machine_features)) {
const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers();
// The size of the reduction dimension should match. The shape inference
// guarantees this invariant, so the check here is for programming
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index a20bf2f9db..566f07ba75 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -31,7 +31,9 @@ limitations under the License.
namespace xla {
namespace cpu {
-bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo);
+bool PotentiallyImplementedAsEigenDot(
+ const HloInstruction& hlo,
+ const TargetMachineFeatures& target_machine_features);
// Returns the index for an operand to `hlo` that should ideally be column
// major. Returns nullopt if there is no such operand or if `hlo` is not a dot
@@ -55,7 +57,7 @@ class DotOpEmitter {
// dimensions as the result, and the result is computed as `addend_array` +
// dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported
// for Matrix-vector products.
- static tensorflow::Status EmitDotOperation(
+ static Status EmitDotOperation(
const HloInstruction& dot, const llvm_ir::IrArray& target_array,
const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
const llvm_ir::IrArray* addend_array,
@@ -74,18 +76,18 @@ class DotOpEmitter {
const TargetMachineFeatures& target_machine_features);
// Emits the IR to perform the dot operation.
- tensorflow::Status Emit();
+ Status Emit();
// Emits instructions to perform a scalar dot product (a multiply of the
// LHS and RHS) and store the results in the target.
- tensorflow::Status EmitScalarDot();
+ Status EmitScalarDot();
// Emit an LLVM IR implementation of the dot operation if we can. Returns
// true if an LLVM IR implementation was emitted.
bool EmitLlvmIrDotIfProfitable();
// Emits a call to the CPU runtime to perform the matrix multiply.
- tensorflow::Status EmitCallToRuntime();
+ Status EmitCallToRuntime();
// Emits a series of nested loops for iterating over an operand array in the
// dot operation. Loops are constructed in major to minor dimension layout
diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
index 7dcc4ca7fa..c562865591 100644
--- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
+++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
@@ -26,13 +26,13 @@ limitations under the License.
namespace xla {
namespace cpu {
-void ExternalConstantPool::Insert(string name, const Literal& literal,
+void ExternalConstantPool::Insert(string name, const LiteralSlice& literal,
int64 alignment) {
CHECK(!ShapeUtil::IsTuple(literal.shape()));
CHECK(alignment > 0 && IsPowerOfTwo(static_cast<uint64>(alignment)));
CHECK(entries_.find(name) == entries_.end());
- int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape());
+ const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape());
void* raw_pointer = tensorflow::port::AlignedMalloc(
literal_size, std::max<size_t>(alignment, sizeof(void*)));
CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size
diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
index 8008a56df4..0677f5f0b5 100644
--- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
+++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
@@ -43,7 +43,7 @@ class ExternalConstantPool {
// The constant pool copies out the contents of `literal` into a buffer it
// owns -- it does not keep pointers to `literal`, or to memory owned by
// `literal`.
- void Insert(string name, const Literal& literal, int64 alignment);
+ void Insert(string name, const LiteralSlice& literal, int64 alignment);
// Find the constant with name `name` in this constant pool. If there isn't
// such constant, return nullptr.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
index f209a69e3c..b560b7531c 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -24,8 +24,25 @@ limitations under the License.
namespace xla {
namespace cpu {
+int64 GetMinimumAlignmentForArray(
+ const Shape& shape, const TargetMachineFeatures& target_machine_features) {
+ CHECK(ShapeUtil::IsArray(shape));
+ CHECK(!LayoutUtil::HasLayout(shape) || LayoutUtil::IsDense(shape.layout()));
+
+ // We don't require a layout to be set on `shape`. This only works on CPU
+ // because we don't pad our tensors or otherwise have complicated data tiling
+ // schemes.
+
+ int64 allocation_size_bytes =
+ ShapeUtil::ElementsIn(shape) *
+ ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
+ return target_machine_features.minimum_alignment_for_allocation(
+ allocation_size_bytes);
+}
+
bool PotentiallyImplementedAsEigenConvolution(
- const HloInstruction& convolution) {
+ const HloInstruction& convolution,
+ const TargetMachineFeatures& target_machine_features) {
// The following conditions are necessary (but not sufficient) for
// implementing `convolution` with Eigen convolution:
// - the input and kernel have a non-zero number of elements.
@@ -35,6 +52,18 @@ bool PotentiallyImplementedAsEigenConvolution(
// To be sufficient, certain layout constraints need to be satisfied as well.
const Shape& input_shape = convolution.operand(0)->shape();
const Shape& kernel_shape = convolution.operand(1)->shape();
+ const Shape& output_shape = convolution.shape();
+
+ auto is_aligned = [&](const Shape& shape) {
+ return GetMinimumAlignmentForArray(shape, target_machine_features) >=
+ TargetMachineFeatures::kEigenExpectedTensorAlignment;
+ };
+
+ if (!is_aligned(input_shape) || !is_aligned(kernel_shape) ||
+ !is_aligned(output_shape)) {
+ return false;
+ }
+
if (ShapeUtil::HasZeroElements(input_shape) ||
ShapeUtil::HasZeroElements(kernel_shape)) {
return false;
@@ -71,7 +100,6 @@ bool PotentiallyImplementedAsEigenConvolution(
}
}
- const Shape& output_shape = convolution.shape();
return dnums.input_batch_dimension() == 0 &&
dnums.input_feature_dimension() == input_shape.dimensions_size() - 1 &&
dnums.output_batch_dimension() == 0 &&
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
index 34b2003916..68fbc7caaa 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
@@ -17,13 +17,20 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
namespace cpu {
bool PotentiallyImplementedAsEigenConvolution(
- const HloInstruction& convolution);
+ const HloInstruction& convolution,
+ const TargetMachineFeatures& target_machine_features);
+
+// Computes the minimum alignment guaranteed for a tensor of shape `shape` on
+// the target machine.
+int64 GetMinimumAlignmentForArray(
+ const Shape& shape, const TargetMachineFeatures& target_machine_features);
// Dynamic loop bounds are specified as an array of dimension index
// [start, limit) pairs of ir values (one for each partitioned outer dimension).
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
index 215f48c4cc..abb2471e6a 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
@@ -39,7 +40,12 @@ ENTRY Conv {
HloComputation* entry_computation = module->entry_computation();
HloInstruction* conv_instr = entry_computation->root_instruction();
- EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(*conv_instr));
+ cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features(
+ [](int64 shape_size) {
+ return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
+ });
+ EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(
+ *conv_instr, target_machine_features));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 55e5aa5063..44cf9ac110 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -83,7 +83,7 @@ IrEmitter::IrEmitter(
llvm::Module* llvm_module,
std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
- llvm::TargetMachine* target_machine,
+ const TargetMachineFeatures* target_machine_features,
ExternalConstantPool* external_constant_pool)
: assignment_(assignment),
module_(llvm_module),
@@ -94,7 +94,7 @@ IrEmitter::IrEmitter(
alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
hlo_module_config_(hlo_module.config()),
is_top_level_computation_(false),
- target_machine_features_(target_machine),
+ target_machine_features_(*target_machine_features),
external_constant_pool_(external_constant_pool) {
ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config_.debug_options()
@@ -227,32 +227,6 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) {
}
}
-// Calculate the alignment of a buffer with a particular size.
-int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) {
- // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on
- // 64-bit platforms. TCMalloc returns a pointer with alignment 8 for
- // allocations smaller than kMallocAlignmentThreshold bytes and at least
- // alignment 16 for allocations greater than or equal to
- // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound
- // by explicitly allocating the memory with posix_memalign. This is
- // complicated by our desire to allow parameter buffers created by clients to
- // be consumed directly by the JIT.
- if (buffer_size == 0) {
- // No need to align empty buffers.
- return 1;
- }
-
- const int64 kMallocAlignmentThreshold = 512;
-
- int pointer_size = module_->getDataLayout().getPointerSize();
- int buffer_alignment = buffer_size >= kMallocAlignmentThreshold
- ? 2 * pointer_size
- : pointer_size;
- DCHECK_GT(buffer_alignment, 0);
-
- return buffer_alignment;
-}
-
// Calculate the alignment of a buffer allocated for a given primitive type.
int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
@@ -277,7 +251,7 @@ int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
DCHECK_GE(buffer_size, 0);
DCHECK_LE(buffer_size, SIZE_MAX);
- return MinimumAlignmentForBufferSize(buffer_size);
+ return target_machine_features_.minimum_alignment_for_allocation(buffer_size);
}
void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
@@ -290,7 +264,8 @@ void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
int64 buffer_size) {
- int alignment = MinimumAlignmentForBufferSize(buffer_size);
+ int alignment =
+ target_machine_features_.minimum_alignment_for_allocation(buffer_size);
if (alignment > 1) {
llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
}
@@ -861,7 +836,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
// TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
// different data layouts.
- if (PotentiallyImplementedAsEigenConvolution(*convolution)) {
+ if (PotentiallyImplementedAsEigenConvolution(*convolution,
+ target_machine_features_)) {
const Shape& lhs_shape = lhs->shape();
const Shape& rhs_shape = rhs->shape();
const Shape& convolution_shape = convolution->shape();
@@ -1027,12 +1003,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
// We will accumulate the products into this sum to calculate
// the output entry at the given index.
PrimitiveType lhs_element_type = lhs->shape().element_type();
+ llvm::Type* lhs_llvm_type =
+ llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_),
- "convolution_sum_address", &ir_builder_,
+ lhs_llvm_type, "convolution_sum_address", &ir_builder_,
MinimumAlignmentForPrimitiveType(lhs_element_type));
- ir_builder_.CreateStore(
- llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), sum_address);
+ llvm::Value* constant_zero =
+ llvm::Constant::getNullValue(lhs_llvm_type);
+ ir_builder_.CreateStore(constant_zero, sum_address);
llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_);
std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 5a04076080..f49cfc1dc3 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -76,7 +76,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
instruction_to_profile_idx,
std::unordered_map<const HloComputation*, int64>
computation_to_profile_idx,
- llvm::TargetMachine* target_machine,
+ const TargetMachineFeatures* target_machine_features,
ExternalConstantPool* external_constant_pool);
~IrEmitter() override;
@@ -514,9 +514,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// Calculate the alignment of a buffer allocated for a given primitive type.
int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);
- // Calculate the alignment of a buffer with a particular size.
- int MinimumAlignmentForBufferSize(int64 buffer_size);
-
// Returns the number of bytes within the shape.
int64 ByteSizeOf(const Shape& shape) const;
@@ -536,7 +533,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
bool is_top_level_computation_;
- TargetMachineFeatures target_machine_features_;
+ const TargetMachineFeatures& target_machine_features_;
int64 external_global_constant_counter_ = 0;
ExternalConstantPool* external_constant_pool_;
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index 47e8405ff2..63d0f7b95c 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -104,7 +104,9 @@ class DefaultCostModel : public ParallelCostModel {
ParallelTaskAssignment::ParallelTaskAssignment(
const int64 max_parallelism,
- const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) {
+ const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module,
+ const TargetMachineFeatures* target_machine_features)
+ : target_machine_features_(*target_machine_features) {
VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
// Run cost analysis on 'module'.
auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
@@ -139,8 +141,10 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
(opcode == HloOpcode::kConvolution &&
- PotentiallyImplementedAsEigenConvolution(*instruction)) ||
- PotentiallyImplementedAsEigenDot(*instruction) ||
+ PotentiallyImplementedAsEigenConvolution(*instruction,
+ target_machine_features_)) ||
+ PotentiallyImplementedAsEigenDot(*instruction,
+ target_machine_features_) ||
(opcode == HloOpcode::kFusion &&
instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) ||
ShapeUtil::IsTuple(instruction->shape())) {
@@ -231,7 +235,8 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper(
void ParallelTaskAssigner::ComputeTargetParallelTasks(
HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
ParallelTaskAssignment parallel_task_assignment(max_parallelism_,
- shape_size_function_, module);
+ shape_size_function_, module,
+ &target_machine_features_);
// Compute parallel task counts for all instructions in 'module'.
for (auto* computation : module->computations()) {
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index 7140dabe51..8becc8fa23 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -39,7 +40,8 @@ class ParallelTaskAssignment {
// 'module': the containing HloModule.
ParallelTaskAssignment(const int64 max_parallelism,
const HloCostAnalysis::ShapeSizeFunction& shape_size,
- HloModule* module);
+ HloModule* module,
+ const TargetMachineFeatures* target_machine_features);
~ParallelTaskAssignment() {}
// Computes and returns the target parallel task count for 'instruction'.
@@ -47,6 +49,7 @@ class ParallelTaskAssignment {
private:
std::unique_ptr<ParallelCostModel> cost_model_;
+ const TargetMachineFeatures& target_machine_features_;
};
// ParallelTaskAssigner computes target parallel task counts for all HLOs
@@ -63,8 +66,11 @@ class ParallelTaskAssigner : public HloPassInterface {
// 'shape_size': shape size function used by HloCostAnalysis during parallel
// task assignment.
ParallelTaskAssigner(const int64 max_parallelism,
- const HloCostAnalysis::ShapeSizeFunction& shape_size)
- : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {}
+ const HloCostAnalysis::ShapeSizeFunction& shape_size,
+ const TargetMachineFeatures* target_machine_features)
+ : max_parallelism_(max_parallelism),
+ shape_size_function_(shape_size),
+ target_machine_features_(*target_machine_features) {}
~ParallelTaskAssigner() override {}
tensorflow::StringPiece name() const override {
@@ -94,6 +100,7 @@ class ParallelTaskAssigner : public HloPassInterface {
int64 max_parallelism_;
HloCostAnalysis::ShapeSizeFunction shape_size_function_;
+ const TargetMachineFeatures& target_machine_features_;
};
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index 13eb75a572..fc2efbaf9a 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -31,6 +32,19 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
// Use any value larger than 2 since we only test whether a module is
// parallelized or not
const int max_parallelism_ = 10;
+
+ cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_;
+
+ ParallelTaskAssignmentTest()
+ : target_machine_features_([](int64 shape_size) {
+ return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
+ }) {}
+
+ StatusOr<bool> RunParallelTaskAssigner(HloModule* module) {
+ return cpu::ParallelTaskAssigner(max_parallelism_, shape_size_func_,
+ &target_machine_features_)
+ .Run(module);
+ }
};
TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) {
@@ -45,9 +59,7 @@ TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) {
)";
ParseAndVerifyModule(hlo_string);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner(
- max_parallelism_, shape_size_func_)
- .Run(&module()));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module()));
EXPECT_FALSE(changed);
}
@@ -74,9 +86,7 @@ TEST_F(ParallelTaskAssignmentTest,
)";
ParseAndVerifyModule(hlo_string);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner(
- max_parallelism_, shape_size_func_)
- .Run(&module()));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module()));
EXPECT_FALSE(changed);
}
@@ -92,9 +102,7 @@ TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) {
)";
ParseAndVerifyModule(hlo_string);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner(
- max_parallelism_, shape_size_func_)
- .Run(&module()));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module()));
EXPECT_FALSE(changed);
}
@@ -108,9 +116,7 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) {
)";
ParseAndVerifyModule(hlo_string);
- TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner(
- max_parallelism_, shape_size_func_)
- .Run(&module()));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module()));
EXPECT_FALSE(changed);
}
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index ff6f0a9d4e..62c97e5641 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -73,20 +73,29 @@ llvm::StringRef GetHostCpuName() {
}
} // namespace
+/*static*/ std::unique_ptr<llvm::TargetMachine>
+SimpleOrcJIT::InferTargetMachineForJIT(
+ const llvm::TargetOptions& target_options,
+ llvm::CodeGenOpt::Level opt_level) {
+ std::unique_ptr<llvm::TargetMachine> target_machine(
+ llvm::EngineBuilder()
+ .setTargetOptions(target_options)
+ .setOptLevel(opt_level)
+ .selectTarget(
+ /*TargetTriple=*/llvm::Triple(), /*MArch=*/"",
+ /*MCPU=*/GetHostCpuName(),
+ /*MAttrs=*/DetectMachineAttributes()));
+ CHECK(target_machine != nullptr);
+ return target_machine;
+}
+
SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
llvm::CodeGenOpt::Level opt_level,
bool optimize_for_size, bool enable_fast_math,
bool disable_expensive_passes,
LLVMCompiler::ModuleHook pre_optimization_hook,
LLVMCompiler::ModuleHook post_optimization_hook)
- : target_machine_(
- CHECK_NOTNULL(llvm::EngineBuilder()
- .setTargetOptions(target_options)
- .setOptLevel(opt_level)
- .selectTarget(
- /*TargetTriple=*/llvm::Triple(), /*MArch=*/"",
- /*MCPU=*/GetHostCpuName(),
- /*MAttrs=*/DetectMachineAttributes()))),
+ : target_machine_(InferTargetMachineForJIT(target_options, opt_level)),
disassembler_(*target_machine_),
data_layout_(target_machine_->createDataLayout()),
symbol_resolver_(llvm::orc::createLegacyLookupResolver(
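Editor's note: hoisting target-machine inference into the static helper InferTargetMachineForJIT lets callers other than the JIT build a matching target machine and wrap it in the new LLVMTargetMachineFeatures. The call sites that do this are outside this patch, so the wiring below is only an illustrative sketch built from the declarations in this diff.

// Illustrative wiring only -- not a call site from this patch.
#include <memory>

#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"

namespace xla {
namespace cpu {

void ExampleWiring() {
  // Build a TargetMachine for the host the same way the JIT does...
  std::unique_ptr<llvm::TargetMachine> target_machine =
      SimpleOrcJIT::InferTargetMachineForJIT(llvm::TargetOptions(),
                                             llvm::CodeGenOpt::Default);
  // ...and expose it to backend passes through the abstract interface.
  LLVMTargetMachineFeatures target_machine_features(target_machine.get());
  (void)target_machine_features.minimum_alignment_for_allocation(1024);
}

}  // namespace cpu
}  // namespace xla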
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
index f4260a95bc..1851a3ee0b 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
@@ -95,6 +95,12 @@ class SimpleOrcJIT {
return &external_constant_pool_;
}
+ // Creates an llvm::TargetMachine suitable for JITting code that will run on
+ // the current machine.
+ static std::unique_ptr<llvm::TargetMachine> InferTargetMachineForJIT(
+ const llvm::TargetOptions& target_options,
+ llvm::CodeGenOpt::Level opt_level);
+
private:
llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name);
diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc
index eeb049737d..a0cd8ee2d2 100644
--- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc
+++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc
@@ -18,7 +18,7 @@ limitations under the License.
namespace xla {
namespace cpu {
-llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor(
+llvm::TargetTransformInfo* LLVMTargetMachineFeatures::GetTargetTransformInfoFor(
const llvm::Function& function) const {
auto it = target_transform_info_cache_.find(&function);
if (it == target_transform_info_cache_.end()) {
@@ -31,5 +31,30 @@ llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor(
return &it->second;
}
+int64 LLVMTargetMachineFeatures::minimum_alignment_for_allocation(
+ int64 size_bytes) const {
+ // GLibc malloc returns a pointer with alignment 8 on 32-bit platforms and 16
+ // on 64-bit platforms. TCMalloc returns a pointer with alignment 8 for
+ // allocations smaller than kMallocAlignmentThreshold bytes and at least
+ // alignment 16 for allocations greater than or equal to
+ // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound
+ // by explicitly allocating the memory with posix_memalign. This is
+ // complicated by our desire to allow parameter buffers created by clients to
+ // be consumed directly by the JIT.
+ if (size_bytes == 0) {
+ // No need to align empty buffers.
+ return 1;
+ }
+
+ const int64 kMallocAlignmentThreshold = 512;
+
+ int pointer_size = target_machine_->getPointerSize(0);
+ int buffer_alignment =
+ size_bytes >= kMallocAlignmentThreshold ? 2 * pointer_size : pointer_size;
+ DCHECK_GT(buffer_alignment, 0);
+
+ return buffer_alignment;
+}
+
} // namespace cpu
} // namespace xla
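Editor's note: for concreteness, the malloc-derived rule above maps allocation sizes to alignments as follows on a typical 64-bit host (pointer size 8): size 0 -> 1, sizes below 512 bytes -> 8, sizes of 512 bytes or more -> 16. Below is a small self-contained sketch of the same decision logic, with the pointer size passed in explicitly instead of queried from llvm::TargetMachine.

#include <cstdint>
#include <iostream>

// Mirrors LLVMTargetMachineFeatures::minimum_alignment_for_allocation above.
int64_t MinimumAlignmentForAllocation(int64_t size_bytes, int pointer_size) {
  if (size_bytes == 0) return 1;  // No need to align empty buffers.
  const int64_t kMallocAlignmentThreshold = 512;
  return size_bytes >= kMallocAlignmentThreshold ? 2 * pointer_size
                                                 : pointer_size;
}

int main() {
  std::cout << MinimumAlignmentForAllocation(0, 8) << "\n";    // 1
  std::cout << MinimumAlignmentForAllocation(100, 8) << "\n";  // 8
  std::cout << MinimumAlignmentForAllocation(512, 8) << "\n";  // 16
}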
diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h
index 703942615e..8b00ae9e47 100644
--- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h
+++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h
@@ -24,43 +24,68 @@ limitations under the License.
namespace xla {
namespace cpu {
-// Wraps an llvm::TargetMachine and parses out some information that feeds into
-// LLVM IR code generation decisions.
+// Abstract interface for classes providing information about the target we're
+// compiling for.
class TargetMachineFeatures {
public:
static constexpr int kX86AvxVectorByteSize = 32;
- TargetMachineFeatures(llvm::TargetMachine* target_machine)
- : target_machine_(target_machine) {}
+ // Input and output tensor buffers must be aligned to this many bytes if we
+ // want to call an Eigen backed GEMM or Convolution.
+ static constexpr int kEigenExpectedTensorAlignment = 16;
// Return the vectorization factor, which is the number of bytes of data
// explicitly vectorized routines will try to process at once.
- int vectorization_factor_in_bytes() const {
- // Ideally this should be a function of the cache line size (which we can
- // get from llvm::TargetTransformInfo::getCacheLineSize) of the target
- // machine. Guess a value of 128 bytes for now.
- return 128;
- }
+ virtual int vectorization_factor_in_bytes() const = 0;
// Return the size of the largest vector size in bytes. We need to pass in
// "function" since llvm functions can contain annotations for specializing
// them to specific micro-architectures (though currently XLA does not use
// this functionality).
- int vector_register_byte_size(const llvm::Function& function) const {
- llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function);
- return tti->getRegisterBitWidth(/*Vector=*/true) / 8;
- }
+ virtual int vector_register_byte_size(
+ const llvm::Function& function) const = 0;
// Return the number of elements of type `type` that can fit into the largest
// vector register available. We need to pass in "function" since llvm
// functions can contain annotations for specializing them to specific
// micro-architectures (though currently XLA does not use this functionality).
+ virtual int vector_register_num_elements(const llvm::Function& function,
+ PrimitiveType type) const = 0;
+
+ // Returns the minimum alignment for a buffer of size size_bytes.
+ virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0;
+
+ virtual ~TargetMachineFeatures() = default;
+};
+
+// Implements the TargetMachineFeatures interface using an llvm::TargetMachine.
+class LLVMTargetMachineFeatures : public TargetMachineFeatures {
+ public:
+ static constexpr int kX86AvxVectorByteSize = 32;
+
+ LLVMTargetMachineFeatures(llvm::TargetMachine* target_machine)
+ : target_machine_(target_machine) {}
+
+ int vectorization_factor_in_bytes() const override {
+ // Ideally this should be a function of the cache line size (which we can
+ // get from llvm::TargetTransformInfo::getCacheLineSize) of the target
+ // machine. Guess a value of 128 bytes for now.
+ return 128;
+ }
+
+ int vector_register_byte_size(const llvm::Function& function) const override {
+ llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function);
+ return tti->getRegisterBitWidth(/*Vector=*/true) / 8;
+ }
+
int vector_register_num_elements(const llvm::Function& function,
- PrimitiveType type) const {
+ PrimitiveType type) const override {
return vector_register_byte_size(function) /
(primitive_util::BitWidth(type) / 8);
}
+ int64 minimum_alignment_for_allocation(int64 size_bytes) const override;
+
private:
llvm::TargetTransformInfo* GetTargetTransformInfoFor(
const llvm::Function& function) const;
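Editor's note: kEigenExpectedTensorAlignment is what gates the Eigen GEMM and convolution paths in dot_op_emitter.cc and ir_emission_utils.cc earlier in this diff: a shape qualifies only if the target guarantees at least 16-byte alignment for its allocation. A toy illustration of that gate using the fake introduced in the next file; everything here other than the two XLA types is made up for the example.

// Toy illustration only; uses the fake defined in
// target_machine_features_fake.h (added below in this patch).
#include <iostream>

#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"

namespace xla {
namespace cpu {

bool AlignedEnoughForEigen(int64 allocation_size_bytes,
                           const TargetMachineFeatures& features) {
  return features.minimum_alignment_for_allocation(allocation_size_bytes) >=
         TargetMachineFeatures::kEigenExpectedTensorAlignment;
}

void Example() {
  TargetMachineFeaturesWithFakeAlignmentLogic eight_byte_target(
      [](int64 /*size*/) { return 8; });
  TargetMachineFeaturesWithFakeAlignmentLogic sixteen_byte_target(
      [](int64 /*size*/) { return 16; });
  std::cout << AlignedEnoughForEigen(1024, eight_byte_target) << "\n";    // 0
  std::cout << AlignedEnoughForEigen(1024, sixteen_byte_target) << "\n";  // 1
}

}  // namespace cpu
}  // namespace xla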
diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h
new file mode 100644
index 0000000000..ffc6927cbe
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_
+
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
+
+namespace xla {
+namespace cpu {
+// Delegates calls to minimum_alignment_for_allocation to a user provided
+// std::function, crashes on all other methods.
+//
+// Primarily useful for testing.
+class TargetMachineFeaturesWithFakeAlignmentLogic
+ : public TargetMachineFeatures {
+ public:
+ explicit TargetMachineFeaturesWithFakeAlignmentLogic(
+ std::function<int64(int64)> fake_alignment_logic)
+ : fake_alignment_logic_(std::move(fake_alignment_logic)) {}
+
+ int vectorization_factor_in_bytes() const override {
+ LOG(FATAL) << "Unexpected call to " << __func__;
+ }
+
+ int vector_register_byte_size(const llvm::Function& function) const override {
+ LOG(FATAL) << "Unexpected call to " << __func__;
+ }
+
+ int vector_register_num_elements(const llvm::Function& function,
+ PrimitiveType type) const override {
+ LOG(FATAL) << "Unexpected call to " << __func__;
+ }
+
+ int64 minimum_alignment_for_allocation(int64 size_bytes) const override {
+ return fake_alignment_logic_(size_bytes);
+ }
+
+ private:
+ std::function<int64(int64)> fake_alignment_logic_;
+};
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc
index 35db4fd2a2..e228bb56bc 100644
--- a/tensorflow/compiler/xla/service/device_memory_allocator.cc
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc
@@ -29,7 +29,7 @@ StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
: DeviceMemoryAllocator(platform),
stream_executors_(stream_executors.begin(), stream_executors.end()) {}
-StatusOr<se::DeviceMemoryBase> StreamExecutorMemoryAllocator::Allocate(
+StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor,
GetStreamExecutor(device_ordinal));
@@ -40,22 +40,17 @@ StatusOr<se::DeviceMemoryBase> StreamExecutorMemoryAllocator::Allocate(
tensorflow::strings::HumanReadableNumBytes(size).c_str(), size,
device_ordinal);
}
- return result;
+ return OwningDeviceMemory(result, device_ordinal, this);
}
-tensorflow::Status StreamExecutorMemoryAllocator::Deallocate(
- int device_ordinal, se::DeviceMemoryBase* mem) {
- if (!mem->is_null()) {
+Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
+ se::DeviceMemoryBase mem) {
+ if (!mem.is_null()) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor,
GetStreamExecutor(device_ordinal));
- // We make a local copy of 'mem' so the original is not zeroed out by the
- // Deallocate() call below. This gives us a better chance of
- // catching double-free bugs, since Deallocate silently succeeds for null
- // values.
- se::DeviceMemoryBase mem_copy(*mem);
- stream_executor->Deallocate(&mem_copy);
+ stream_executor->Deallocate(&mem);
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
StatusOr<se::StreamExecutor*> StreamExecutorMemoryAllocator::GetStreamExecutor(
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h
index da45c4d45a..d87b86caf0 100644
--- a/tensorflow/compiler/xla/service/device_memory_allocator.h
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
+#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -37,28 +38,29 @@ class DeviceMemoryAllocator {
: platform_(platform) {}
virtual ~DeviceMemoryAllocator() {}
+ // Allocates memory on the device.
+ //
+ // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory
+ // must not be null. If size == 0, must return a null OwningDeviceMemory.
+ //
// 'retry_on_failure': If false, and the first attempt to allocate the memory
// fails, the allocation should return immediately without retrying. An
// example use case is optional scratch spaces where a failure has only
// performance impact.
- //
- // Allocate() should return a null pointer for a size-0 allocation.
- // Deallocate() must be a no-op for null pointers.
- virtual StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal,
- uint64 size,
- bool retry_on_failure) = 0;
+ virtual StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
+ bool retry_on_failure) = 0;
// Two-arg version of Allocate(), which sets retry-on-failure to true.
//
// (We don't simply use a default argument on the virtual Allocate function
// because default args on virtual functions are disallowed by the Google
// style guide.)
- StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size) {
+ StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size) {
return Allocate(device_ordinal, size, /*retry_on_failure=*/true);
}
- virtual tensorflow::Status Deallocate(int device_ordinal,
- se::DeviceMemoryBase* mem) = 0;
+ // Must be a nop for null pointers.
+ virtual Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) = 0;
// Return the platform that the allocator allocates memory on.
const se::Platform* platform() const { return platform_; }
@@ -68,6 +70,7 @@ class DeviceMemoryAllocator {
virtual bool AllowsAsynchronousDeallocation() const = 0;
protected:
+ friend class OwningDeviceMemory;
const se::Platform* platform_;
};
@@ -79,14 +82,13 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
const se::Platform* platform,
tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors);
- StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
- bool retry_on_failure) override;
+ StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
+ bool retry_on_failure) override;
// Pull in two-arg overload that sets retry_on_failure to true.
using DeviceMemoryAllocator::Allocate;
- tensorflow::Status Deallocate(int device_ordinal,
- se::DeviceMemoryBase* mem) override;
+ Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
bool AllowsAsynchronousDeallocation() const override;
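Editor's note: the ownership contract implied by the new signatures (Allocate hands back an owning handle; Deallocate takes the unowned view by value and must tolerate null) is the usual RAII pattern. Below is a minimal sketch of such an owning handle, using made-up Fake* names since owning_device_memory.h itself is not part of this section of the diff.

#include <cstdint>
#include <iostream>

// Made-up miniature of the DeviceMemoryBase / OwningDeviceMemory split.
struct FakeDeviceMemory {
  void* opaque = nullptr;
  bool is_null() const { return opaque == nullptr; }
};

class FakeAllocator {
 public:
  FakeDeviceMemory Allocate(uint64_t size) {
    return FakeDeviceMemory{size == 0 ? nullptr : ::operator new(size)};
  }
  // Must be a nop for null, mirroring the documented contract above.
  void Deallocate(FakeDeviceMemory mem) {
    if (!mem.is_null()) ::operator delete(mem.opaque);
  }
};

// Owning handle: frees on destruction unless ownership is released (Forget).
class FakeOwningMemory {
 public:
  FakeOwningMemory(FakeDeviceMemory mem, FakeAllocator* allocator)
      : mem_(mem), allocator_(allocator) {}
  ~FakeOwningMemory() {
    if (allocator_ != nullptr) allocator_->Deallocate(mem_);
  }
  FakeDeviceMemory Forget() {  // Caller now owns the memory.
    allocator_ = nullptr;
    return mem_;
  }

 private:
  FakeDeviceMemory mem_;
  FakeAllocator* allocator_;
};

int main() {
  FakeAllocator allocator;
  FakeOwningMemory scratch(allocator.Allocate(256), &allocator);
  // `scratch` is freed automatically at end of scope; calling scratch.Forget()
  // instead would hand the memory to the caller, which is what
  // CreateResultShapedBuffer does for buffers that are live out of the
  // computation.
  std::cout << "allocated\n";
}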
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 0528b07602..b9d7ec9c2e 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -138,6 +138,9 @@ class DfsHloVisitorBase {
virtual Status HandleExp(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
+ virtual Status HandleExpm1(HloInstructionPtr hlo) {
+ return HandleElementwiseUnary(hlo);
+ }
virtual Status HandleFloor(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
@@ -150,6 +153,9 @@ class DfsHloVisitorBase {
virtual Status HandleClz(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
+ virtual Status HandleLog1p(HloInstructionPtr hlo) {
+ return HandleElementwiseUnary(hlo);
+ }
virtual Status HandleCos(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index ae32d33766..0a400e982a 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -418,8 +418,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
case HloOpcode::kExp:
return EmitExp(op->shape().element_type(), operand_value);
+ case HloOpcode::kExpm1:
+ return EmitExpm1(op->shape().element_type(), operand_value);
case HloOpcode::kLog:
return EmitLog(op->shape().element_type(), operand_value);
+ case HloOpcode::kLog1p:
+ return EmitLog1p(op->shape().element_type(), operand_value);
case HloOpcode::kCos:
return EmitCos(op->shape().element_type(), operand_value);
case HloOpcode::kSin:
@@ -493,6 +497,22 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
return EmitComposeComplex(
op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle);
}
+ case HloOpcode::kLog1p: {
+ // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
+ auto a = EmitExtractReal(operand_value);
+ auto b = EmitExtractImag(operand_value);
+ llvm::Type* llvm_ty = a->getType();
+ auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
+ auto a_plus_one = ir_builder_->CreateFAdd(a, one);
+ auto sum_sq = ir_builder_->CreateFAdd(
+ ir_builder_->CreateFMul(a_plus_one, a_plus_one),
+ ir_builder_->CreateFMul(b, b));
+ TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
+ TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one));
+ auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
+ return EmitComposeComplex(
+ op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle);
+ }
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
TF_RET_CHECK(primitive_util::IsComplexType(from_type));
@@ -523,6 +543,20 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
ir_builder_->CreateFMul(exp_a, sin_b));
}
+ case HloOpcode::kExpm1: {
+ // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
+ TF_ASSIGN_OR_RETURN(
+ auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
+ TF_ASSIGN_OR_RETURN(
+ auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
+ TF_ASSIGN_OR_RETURN(
+ auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
+ auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
+ auto real_result =
+ ir_builder_->CreateFSub(ir_builder_->CreateFMul(exp_a, cos_b), one);
+ auto imag_result = ir_builder_->CreateFMul(exp_a, sin_b);
+ return EmitComposeComplex(op, real_result, imag_result);
+ }
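The kExpm1 identity is Euler's formula minus one: e^(a+bi) - 1 = e^a*(cos(b) + i*sin(b)) - 1, so the real part is e^a*cos(b) - 1 and the imaginary part is e^a*sin(b), matching real_result and imag_result above.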
case HloOpcode::kCos: {
// cos(z) = .5(e^(iz) + e^(-iz))
// cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
@@ -975,6 +1009,28 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
{value->getType()}, ir_builder_);
}
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
+ llvm::Value* value) const {
+ auto x = value;
+ auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
+ auto one = llvm::ConstantFP::get(type, 1.0);
+ auto negative_half = llvm::ConstantFP::get(type, -0.5);
+ // When x is large, the naive evaluation of ln(x + 1) is more
+ // accurate than the Taylor series.
+ TF_ASSIGN_OR_RETURN(auto for_large_x,
+ EmitLog(prim_type, ir_builder_->CreateFAdd(x, one)));
+ // The Taylor series for ln(x+1) is x - x^2/2 + x^3/3 - ….
+ auto for_small_x = ir_builder_->CreateFMul(
+ ir_builder_->CreateFAdd(ir_builder_->CreateFMul(negative_half, x), one),
+ x);
+ const auto kAntilogarithmIsSmallThreshold = 1e-4;
+ auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value},
+ {type}, ir_builder_);
+ auto x_is_small = ir_builder_->CreateFCmpOLT(
+ abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
+ return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x);
+}
+
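As a sanity check on the select-based approximation EmitLog1p builds, here is a minimal scalar sketch of the same logic in plain C++ (not the IR-emitting code above; the function name and the comparison against std::log1p are illustrative only):

  #include <cmath>
  #include <cstdio>

  // Mirrors the emitted select: a two-term Taylor series (1 - x/2) * x for
  // tiny |x|, and the naive log(x + 1) otherwise.
  double Log1pApprox(double x) {
    const double kThreshold = 1e-4;  // same role as kAntilogarithmIsSmallThreshold
    double for_small_x = (1.0 + (-0.5) * x) * x;
    double for_large_x = std::log(x + 1.0);
    return std::fabs(x) < kThreshold ? for_small_x : for_large_x;
  }

  int main() {
    std::printf("%.17g vs %.17g\n", Log1pApprox(1e-6), std::log1p(1e-6));
    return 0;
  }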
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
llvm::Value* value) const {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
@@ -993,6 +1049,29 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
{value->getType()}, ir_builder_);
}
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
+ llvm::Value* value) const {
+ auto x = value;
+ auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
+ auto one = llvm::ConstantFP::get(type, 1.0);
+ auto half = llvm::ConstantFP::get(type, 0.5);
+ // When the exponent is large, the naive evaluation of e^(x) - 1 is more
+ // accurate than the Taylor series.
+ TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value));
+ auto for_large_x = ir_builder_->CreateFSub(exp_x, one);
+ // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
+ // We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
+ auto x_squared = ir_builder_->CreateFMul(x, x);
+ auto x_squared_over_two = ir_builder_->CreateFMul(x_squared, half);
+ auto for_small_x = ir_builder_->CreateFAdd(x, x_squared_over_two);
+ const auto kExponentIsSmallThreshold = 1e-5;
+ auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value},
+ {type}, ir_builder_);
+ auto x_is_small = ir_builder_->CreateFCmpOLT(
+ abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
+ return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x);
+}
+
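EmitExpm1 follows the same two-branch pattern; a matching scalar sketch (a drop-in companion to the log1p sketch above, helper name again illustrative):

  #include <cmath>

  // Two-term series x + x^2/2 for tiny |x|, naive exp(x) - 1 otherwise,
  // mirroring kExponentIsSmallThreshold above.
  double Expm1Approx(double x) {
    const double kThreshold = 1e-5;
    double for_small_x = x + 0.5 * x * x;
    double for_large_x = std::exp(x) - 1.0;
    return std::fabs(x) < kThreshold ? for_small_x : for_large_x;
  }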
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) const {
@@ -1784,8 +1863,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
const llvm_ir::IrArray::Index& dot_result_index) const {
auto lhs_generator = operand_to_generator.at(hlo->operand(0));
auto rhs_generator = operand_to_generator.at(hlo->operand(1));
- int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
- hlo->operand(0)->shape().dimensions_size() - 1);
+
+ const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
+ int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
+ int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);
+
+ int64 contracted_dim_size =
+ hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
@@ -1816,13 +1900,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
for (int64 i = 0; i < lhs_dims - 1; i++) {
lhs_index.push_back(dot_result_index[i]);
}
- lhs_index.push_back(inner_loop->GetIndVarValue());
+ lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue());
- for (int64 i = 0; i < rhs_dims - 2; i++) {
+ for (int64 i = 0; i < rhs_dims - 1; i++) {
rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
}
- rhs_index.push_back(inner_loop->GetIndVarValue());
- rhs_index.push_back(dot_result_index.back());
+ rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue());
llvm::Value* current_accumulator =
ir_builder_->CreateLoad(accumulator_alloca);
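The index handling above replaces the old assumption that the contracted dimension is always the last lhs dimension (and the second-to-last rhs dimension): the inner-loop induction variable k is now inserted at the positions given by DotDimensionNumbers. In the common rank-2 case, lhs_contracting_dim = 1 and rhs_contracting_dim = 0, which reproduces dot(L, R)[i, j] = sum over k of L[i, k] * R[k, j]; other contracting positions (as exercised by the DotFusion test added below) only change where k lands in each operand's index.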
@@ -1877,10 +1960,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kReal:
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 26dff0d96f..d199473374 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -105,6 +105,9 @@ class ElementalIrEmitter {
virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
llvm::Value* value) const;
+ virtual StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type,
+ llvm::Value* value) const;
+
virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
llvm::Value* value) const;
@@ -114,6 +117,9 @@ class ElementalIrEmitter {
virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
llvm::Value* value) const;
+ virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
+ llvm::Value* value) const;
+
virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) const;
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
new file mode 100644
index 0000000000..b43dc0c65d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+
+namespace xla {
+namespace {
+
+using tensorflow::gtl::nullopt;
+
+class ElementalIrEmitterExecutionTest : public HloTestBase {
+ protected:
+ void RunTest(const string& hlo_text,
+ tensorflow::gtl::ArraySlice<Literal*> args) {
+ HloModuleConfig config;
+ config.set_debug_options(GetDebugOptionsForTest());
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_text, config));
+ EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt));
+ }
+};
+
+XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) {
+ const string hlo_text = R"(
+HloModule FusedDot
+
+fused_computation {
+ arg0 = s32[1,2,1]{2,1,0} parameter(0)
+ reshape.lhs = s32[2,1]{1,0} reshape(arg0)
+ arg1 = s32[1,2,1]{2,1,0} parameter(1)
+ reshape.rhs = s32[2,1]{1,0} reshape(arg1)
+ ROOT dot = s32[1,1]{1,0} dot(reshape.lhs, reshape.rhs), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+}
+
+ENTRY main {
+ entry_arg0 = s32[1,2,1]{2,1,0} parameter(0)
+ entry_arg1 = s32[1,2,1]{2,1,0} parameter(1)
+ ROOT fusion = s32[1,1]{1,0} fusion(entry_arg0, entry_arg1), kind=kLoop, calls=fused_computation
+}
+)";
+
+ std::unique_ptr<Literal> lhs = Literal::CreateR3<int32>({{{1}, {2}}});
+ std::unique_ptr<Literal> rhs = Literal::CreateR3<int32>({{{3}, {4}}});
+ RunTest(hlo_text, {lhs.get(), rhs.get()});
+}
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc
index 2f0b9ed2bd..6794cfe297 100644
--- a/tensorflow/compiler/xla/service/execution_tracker.cc
+++ b/tensorflow/compiler/xla/service/execution_tracker.cc
@@ -37,11 +37,11 @@ AsyncExecution::AsyncExecution(Backend* backend,
}
}
-tensorflow::Status AsyncExecution::BlockUntilDone() const {
+Status AsyncExecution::BlockUntilDone() const {
for (auto& stream : streams_) {
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
ExecutionTracker::ExecutionTracker() : next_handle_(1) {}
@@ -61,7 +61,7 @@ ExecutionHandle ExecutionTracker::Register(
return execution_handle;
}
-tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) {
+Status ExecutionTracker::Unregister(const ExecutionHandle& handle) {
tensorflow::mutex_lock lock(execution_mutex_);
auto it = handle_to_execution_.find(handle.handle());
if (it == handle_to_execution_.end()) {
@@ -69,7 +69,7 @@ tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) {
handle.handle());
}
handle_to_execution_.erase(handle.handle());
- return tensorflow::Status::OK();
+ return Status::OK();
}
StatusOr<const AsyncExecution*> ExecutionTracker::Resolve(
diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h
index 5b6bddf9f1..4458152dd9 100644
--- a/tensorflow/compiler/xla/service/execution_tracker.h
+++ b/tensorflow/compiler/xla/service/execution_tracker.h
@@ -43,7 +43,7 @@ class AsyncExecution {
AsyncExecution(Backend* backend, std::vector<Backend::StreamPtr> streams,
const ExecutionProfile& profile, GlobalDataHandle result);
- tensorflow::Status BlockUntilDone() const;
+ Status BlockUntilDone() const;
const GlobalDataHandle& result() const { return result_; }
@@ -77,7 +77,7 @@ class ExecutionTracker {
GlobalDataHandle data);
// Unregisters the execution for the given handle.
- tensorflow::Status Unregister(const ExecutionHandle& handle);
+ Status Unregister(const ExecutionHandle& handle);
// Resolves the given ExecutionHandle to an AsyncExecution. Returns an
// error status if the given handle is not found, which means that the
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index ddb687314e..dbf1ab6690 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -115,7 +115,7 @@ Status GenericTransferManager::TransferLiteralToDevice(
TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
device_memory.size());
// Element is array-shaped: transfer array data to device buffer.
- const auto subliteral = LiteralView::Create(literal, index);
+ const auto subliteral = LiteralSlice(literal, index);
std::unique_ptr<Literal> relayed_out_literal;
const void* source;
if (LayoutUtil::Equal(device_subshape.layout(),
@@ -137,7 +137,7 @@ Status GenericTransferManager::TransferLiteralToDevice(
}
Status GenericTransferManager::TransferLiteralToInfeed(
- se::StreamExecutor* executor, const Literal& literal) {
+ se::StreamExecutor* executor, const LiteralSlice& literal) {
return Unimplemented("Generic transfer to Infeed");
}
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 0579099de4..3343eca851 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -49,7 +49,7 @@ class GenericTransferManager : public TransferManager {
const ShapedBuffer& device_buffer) override;
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) override;
+ const LiteralSlice& literal) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
Literal* literal) override;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 7cb7f55073..7ee039b3eb 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -388,8 +388,10 @@ cc_library(
deps = [
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:instruction_fusion",
+ "//tensorflow/compiler/xla/service:pattern_matcher",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
index 837f05244f..ab5149dcdb 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
@@ -37,11 +37,11 @@ void BufferAllocations::Builder::RegisterBuffer(BufferAllocation::Index index,
}
StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
- const BufferAssignment& buffer_assignment, int device_ordinal,
+ const BufferAssignment* buffer_assignment, int device_ordinal,
DeviceMemoryAllocator* memory_allocator) {
- const int64 num_buffers = buffer_assignment.Allocations().size();
- auto buffer_allocations = WrapUnique(
- new BufferAllocations(num_buffers, device_ordinal, memory_allocator));
+ const int64 num_buffers = buffer_assignment->Allocations().size();
+ auto buffer_allocations = WrapUnique(new BufferAllocations(
+ num_buffers, device_ordinal, memory_allocator, buffer_assignment));
for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
// If buffer #i's address is already registered (e.g. external arguments or
@@ -62,28 +62,28 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
// Allocate each allocation that might escape, or is the temp buffer.
bool seen_temp_buffer = false;
- const BufferAllocation& allocation = buffer_assignment.GetAllocation(i);
+ const BufferAllocation& allocation = buffer_assignment->GetAllocation(i);
if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) {
const int64 buffer_size = allocation.size();
se::DeviceMemoryBase buffer_address;
if (buffer_size > 0) {
- TF_ASSIGN_OR_RETURN(buffer_address, memory_allocator->Allocate(
- device_ordinal, buffer_size));
- if (buffer_address == nullptr) {
- return ResourceExhausted(
- "Out of memory when allocating %s for buffer %lld.",
- tensorflow::strings::HumanReadableNumBytes(buffer_size).c_str(),
- i);
- }
- if (reinterpret_cast<uintptr_t>(buffer_address.opaque()) %
+ OwningDeviceMemory buffer;
+ TF_ASSIGN_OR_RETURN(
+ buffer, memory_allocator->Allocate(device_ordinal, buffer_size));
+ if (reinterpret_cast<uintptr_t>(buffer.opaque()) %
kCudaMallocAlignBytes !=
0) {
return InternalError(
"Address returned by memory_allocator->Allocate must be a "
"multiple of %llx, but was %p",
- kCudaMallocAlignBytes, buffer_address.opaque());
+ kCudaMallocAlignBytes, buffer.opaque());
}
+ // We do manual memory management within BufferAllocations. Be sure not
+ // to do a TF_RETURN_IF_ERROR between this line and the
+ // buffer_allocations->SetBuffer(buffer_address) call below!
+ buffer_address = buffer.Forget();
}
+
buffer_allocations->SetBuffer(i, buffer_address);
if (allocation.IsPreallocatedTempBuffer()) {
if (seen_temp_buffer) {
@@ -103,28 +103,42 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
<< "B)";
}
}
-
return std::move(buffer_allocations);
}
-tensorflow::Status BufferAllocations::TearDown(
- const std::set<se::DeviceMemoryBase>& live_addresses,
- const BufferAssignment& buffer_assignment) {
- // Deallocate temporary buffers.
- const int64 num_buffers = buffer_assignment.Allocations().size();
+BufferAllocations::~BufferAllocations() {
+ if (!torn_down_) {
+ // Presumably if we're executing this branch, the caller is in an error
+ // state, otherwise it would have explicitly called TearDown so it could
+ // save some set of live addresses. So ignoring any errors in TearDown is
+ // sensible.
+ TearDown(/*live_addresses=*/{}).IgnoreError();
+ }
+}
+
+Status BufferAllocations::TearDown(
+ const std::set<se::DeviceMemoryBase>& live_addresses) {
+ // Deallocate temporary buffers, taking care to try to deallocate all of them
+ // even if one of the deallocations fails.
+ Status status;
+ const int64 num_buffers = buffer_assignment_->Allocations().size();
for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
- const BufferAllocation& allocation = buffer_assignment.GetAllocation(i);
+ const BufferAllocation& allocation = buffer_assignment_->GetAllocation(i);
se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index());
// Deallocate buffers marked "maybe_live_out" but aren't actually live out,
// and temp buffers.
if ((allocation.maybe_live_out() &&
!live_addresses.count(buffer_address)) ||
allocation.IsPreallocatedTempBuffer()) {
- TF_RETURN_IF_ERROR(
- memory_allocator_->Deallocate(device_ordinal_, &buffer_address));
+ auto dealloc_result =
+ memory_allocator_->Deallocate(device_ordinal_, buffer_address);
+ if (!dealloc_result.ok() && status.ok()) {
+ status = dealloc_result;
+ }
}
}
- return tensorflow::Status::OK();
+ torn_down_ = true;
+ return status;
}
se::DeviceMemoryBase BufferAllocations::GetDeviceAddress(
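The reworked TearDown deallocates best-effort and reports only the first failure. A minimal stand-alone sketch of that pattern, with a stand-in Status type (names here are illustrative, not the real xla::Status):

  #include <string>
  #include <vector>

  struct Status {
    std::string msg;                       // empty means OK
    bool ok() const { return msg.empty(); }
  };

  // Keep deallocating even if one call fails; remember only the first error.
  Status DeallocateAll(const std::vector<int>& handles, Status (*dealloc)(int)) {
    Status first_error;
    for (int h : handles) {
      Status s = dealloc(h);
      if (!s.ok() && first_error.ok()) {
        first_error = s;
      }
    }
    return first_error;
  }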
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
index c2fc35be4c..6366235025 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
@@ -48,13 +48,15 @@ class BufferAllocations {
// `device_ordinal` is the number of the device this function allocates
// memory on.
StatusOr<std::unique_ptr<BufferAllocations>> Build(
- const BufferAssignment& buffer_assignment, int device_ordinal,
+ const BufferAssignment* buffer_assignment, int device_ordinal,
DeviceMemoryAllocator* memory_allocator);
private:
std::map<BufferAllocation::Index, se::DeviceMemoryBase> registered_buffers_;
};
+ ~BufferAllocations();
+
BufferAllocations(const BufferAllocations&) = delete;
BufferAllocations& operator=(const BufferAllocations&) = delete;
@@ -76,16 +78,16 @@ class BufferAllocations {
// Tears down all buffers allocated by this object that are not in
// `live_addresses`.
- tensorflow::Status TearDown(
- const std::set<se::DeviceMemoryBase>& live_addresses,
- const BufferAssignment& buffer_assignment);
+ Status TearDown(const std::set<se::DeviceMemoryBase>& live_addresses);
private:
BufferAllocations(BufferAllocation::Index buffer_count, int device_ordinal,
- DeviceMemoryAllocator* memory_allocator)
+ DeviceMemoryAllocator* memory_allocator,
+ const BufferAssignment* buffer_assignment)
: buffers_(buffer_count),
device_ordinal_(device_ordinal),
- memory_allocator_(memory_allocator) {}
+ memory_allocator_(memory_allocator),
+ buffer_assignment_(buffer_assignment) {}
// Sets the device address of buffer `buffer_index`.
void SetBuffer(BufferAllocation::Index buffer_index,
@@ -100,8 +102,9 @@ class BufferAllocations {
se::DeviceMemoryBase temp_buffer_base_;
int device_ordinal_;
-
DeviceMemoryAllocator* memory_allocator_;
+ const BufferAssignment* buffer_assignment_;
+ bool torn_down_ = false;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
index dce8de2e30..77a48965e0 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -35,9 +35,10 @@ ConditionalThunk::ConditionalThunk(
true_thunk_(std::move(true_thunk_sequence), hlo),
false_thunk_(std::move(false_thunk_sequence), hlo) {}
-Status ConditionalThunk::Initialize(const GpuExecutable& executable) {
- TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable));
- TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable));
+Status ConditionalThunk::Initialize(const GpuExecutable& executable,
+ se::StreamExecutor* executor) {
+ TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable, executor));
+ TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable, executor));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
index e40872688f..ee03865d17 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
@@ -47,7 +47,8 @@ class ConditionalThunk : public Thunk {
ConditionalThunk(const ConditionalThunk&) = delete;
ConditionalThunk& operator=(const ConditionalThunk&) = delete;
- Status Initialize(const GpuExecutable& executable) override;
+ Status Initialize(const GpuExecutable& executable,
+ se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream) override;
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 64d3b84b8c..f088112412 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -29,11 +29,6 @@ namespace xla {
namespace gpu {
using se::dnn::AlgorithmDesc;
-using se::dnn::BatchDescriptor;
-using se::dnn::ConvolutionDescriptor;
-using se::dnn::DataLayout;
-using se::dnn::FilterDescriptor;
-using se::dnn::FilterLayout;
ConvolutionThunk::ConvolutionThunk(
CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer,
diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc
index bf912fbd14..ee38c0318a 100644
--- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc
@@ -29,12 +29,12 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk(
destination_buffer_(destination_buffer),
mem_size_(mem_size) {}
-tensorflow::Status HostToDeviceCopyThunk::ExecuteOnStream(
+Status HostToDeviceCopyThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream) {
se::DeviceMemoryBase destination_data =
buffer_allocations.GetDeviceAddress(destination_buffer_);
stream->ThenMemcpy(&destination_data, source_address_, mem_size_);
- return tensorflow::Status::OK();
+ return Status::OK();
}
DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
@@ -46,14 +46,14 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
destination_buffer_(destination_buffer),
mem_size_(mem_size) {}
-tensorflow::Status DeviceToDeviceCopyThunk::ExecuteOnStream(
+Status DeviceToDeviceCopyThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream) {
se::DeviceMemoryBase destination_data =
buffer_allocations.GetDeviceAddress(destination_buffer_);
se::DeviceMemoryBase source_data =
buffer_allocations.GetDeviceAddress(source_buffer_);
stream->ThenMemcpy(&destination_data, source_data, mem_size_);
- return tensorflow::Status::OK();
+ return Status::OK();
}
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h
index 2e7eb5f344..8b128386f6 100644
--- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h
@@ -39,8 +39,8 @@ class HostToDeviceCopyThunk : public Thunk {
HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete;
HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete;
- tensorflow::Status ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) override;
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) override;
private:
const void* source_address_;
@@ -62,8 +62,8 @@ class DeviceToDeviceCopyThunk : public Thunk {
DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete;
DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete;
- tensorflow::Status ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) override;
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) override;
private:
const BufferAllocation::Slice source_buffer_;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 41ee45f55f..6a46bdb9b4 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -35,35 +35,22 @@ class ScratchAllocator : public se::ScratchAllocator {
ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator)
: device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
- ~ScratchAllocator() override;
-
int64 GetMemoryLimitInBytes(se::Stream* stream) override {
return 1LL << 32; // 4GB. TODO(jlebar): Tune this?
}
int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
- se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
- se::Stream* stream, int64 byte_size) override;
+ StatusOr<se::DeviceMemory<uint8>> AllocateBytes(se::Stream* stream,
+ int64 byte_size) override;
private:
const int device_ordinal_;
DeviceMemoryAllocator* memory_allocator_;
- std::vector<se::DeviceMemoryBase> allocated_buffers_;
+ std::vector<OwningDeviceMemory> allocated_buffers_;
int64 total_allocated_bytes_ = 0;
};
-ScratchAllocator::~ScratchAllocator() {
- for (auto& allocated_buffer : allocated_buffers_) {
- if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer)
- .ok()) {
- // The program can still continue with failed deallocation.
- LOG(ERROR) << "Failed to deallocate the allocated buffer: "
- << allocated_buffer.opaque();
- }
- }
-}
-
-se::port::StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
+StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
se::Stream* stream, int64 byte_size) {
CHECK_GE(byte_size, 0) << "byte_size must be positive.";
if (byte_size > GetMemoryLimitInBytes(stream)) {
@@ -74,19 +61,14 @@ se::port::StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
byte_size, GetMemoryLimitInBytes(stream)));
}
- auto status_or_memory =
- memory_allocator_->Allocate(device_ordinal_, byte_size,
- /*retry_on_failure=*/false);
- if (!status_or_memory.ok()) {
- return se::port::Status(se::port::error::RESOURCE_EXHAUSTED,
- tensorflow::strings::Printf(
- "Failed to allocate %lld bytes on device %d.",
- byte_size, device_ordinal_));
- }
- se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie();
- allocated_buffers_.push_back(allocated_buffer);
+ TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer,
+ memory_allocator_->Allocate(device_ordinal_, byte_size,
+ /*retry_on_failure=*/false));
total_allocated_bytes_ += byte_size;
- return se::DeviceMemory<uint8>(allocated_buffer);
+
+ se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase();
+ allocated_buffers_.push_back(std::move(allocated_buffer));
+ return se::DeviceMemory<uint8>(buffer_addr);
}
// Determines whether we can safely perform a winograd non-fused convolution for
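Both scratch allocators now hand out OwningDeviceMemory and must capture the raw device address before moving the owning handle into allocated_buffers_. A small sketch of that ordering constraint using ordinary C++ types (OwningBuffer is a stand-in, not the real class):

  #include <cstddef>
  #include <memory>
  #include <utility>
  #include <vector>

  struct OwningBuffer {
    std::unique_ptr<char[]> data;
    char* view() const { return data.get(); }  // analogous to AsDeviceMemoryBase()
  };

  char* AllocateTracked(std::vector<OwningBuffer>* kept, std::size_t bytes) {
    OwningBuffer buf{std::make_unique<char[]>(bytes)};
    char* addr = buf.view();           // take the non-owning view first...
    kept->push_back(std::move(buf));   // ...then transfer ownership to the container
    return addr;                       // valid as long as 'kept' holds the buffer
  }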
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 5af7a77ea8..e5e2a0478a 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -227,6 +227,11 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(
return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type);
}
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(
+ PrimitiveType prim_type, llvm::Value* value) const {
+ return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type);
+}
+
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(
PrimitiveType prim_type, llvm::Value* value) const {
return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type);
@@ -242,6 +247,11 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(
return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type);
}
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(
+ PrimitiveType prim_type, llvm::Value* value) const {
+ return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type);
+}
+
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) const {
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 77d4569b1e..91f4d960aa 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -64,6 +64,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
llvm::Value* value) const override;
+ StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type,
+ llvm::Value* value) const override;
+
StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
llvm::Value* value) const override;
@@ -73,6 +76,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
llvm::Value* value) const override;
+ StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
+ llvm::Value* value) const override;
+
StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) const override;
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
index cc747addbd..e14ee6918b 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -31,23 +31,12 @@ FftScratchAllocator::FftScratchAllocator(
int device_ordinal, DeviceMemoryAllocator* memory_allocator)
: device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
-FftScratchAllocator::~FftScratchAllocator() {
- for (auto& allocated_buffer : allocated_buffers_) {
- if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer)
- .ok()) {
- // The program can still continue with failed deallocation.
- LOG(ERROR) << "Failed to deallocate the allocated buffer: "
- << allocated_buffer.opaque();
- }
- }
-}
-
int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) {
constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default.
return kFftScratchSize;
}
-se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
+StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
se::Stream* stream, int64 byte_size) {
CHECK_GE(byte_size, 0) << "byte_size must be positive.";
if (byte_size > GetMemoryLimitInBytes(stream)) {
@@ -58,18 +47,14 @@ se::port::StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
byte_size, GetMemoryLimitInBytes(stream)));
}
- auto status_or_memory =
- memory_allocator_->Allocate(device_ordinal_, byte_size,
- /*retry_on_failure=*/false);
- if (!status_or_memory.ok()) {
- return tensorflow::errors::ResourceExhausted(
- "Failed to allocate %lld bytes on device %d.", byte_size,
- device_ordinal_);
- }
- se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie();
- allocated_buffers_.push_back(allocated_buffer);
+ TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer,
+ memory_allocator_->Allocate(device_ordinal_, byte_size,
+ /*retry_on_failure=*/false));
total_allocated_bytes_ += byte_size;
- return se::DeviceMemory<uint8>(allocated_buffer);
+
+ se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase();
+ allocated_buffers_.push_back(std::move(allocated_buffer));
+ return se::DeviceMemory<uint8>(buffer_addr);
}
namespace {
@@ -121,8 +106,8 @@ FftThunk::FftThunk(FftType fft_type,
input_shape_(input_shape),
output_shape_(output_shape) {}
-tensorflow::Status FftThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) {
VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);
VLOG(3) << "Output shape: "
@@ -222,7 +207,7 @@ tensorflow::Status FftThunk::ExecuteOnStream(
LOG(FATAL) << "unsupported fft type";
}
if (launch_ok) {
- return tensorflow::Status::OK();
+ return Status::OK();
}
return InternalError("Unable to launch fft for thunk %p with type %s", this,
FftTypeToString(fft_type_).c_str());
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
index 24b1dca998..b0a22564f3 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -39,8 +39,6 @@ class FftScratchAllocator : public se::ScratchAllocator {
FftScratchAllocator(int device_ordinal,
DeviceMemoryAllocator* memory_allocator);
- ~FftScratchAllocator() override;
-
int64 GetMemoryLimitInBytes(se::Stream* stream) override;
int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
@@ -51,7 +49,7 @@ class FftScratchAllocator : public se::ScratchAllocator {
private:
const int device_ordinal_;
DeviceMemoryAllocator* memory_allocator_;
- std::vector<se::DeviceMemoryBase> allocated_buffers_;
+ std::vector<OwningDeviceMemory> allocated_buffers_;
int64 total_allocated_bytes_ = 0;
};
@@ -73,8 +71,8 @@ class FftThunk : public Thunk {
FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_
// Does the FFT for the thunk on "stream".
- tensorflow::Status ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) override;
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) override;
private:
const se::fft::Type fft_type_;
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index 6e6966df39..b36539e0cb 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -30,19 +30,20 @@ ForThunk::ForThunk(const int64 loop_limit,
body_thunk_sequence_(
MakeUnique<SequentialThunk>(std::move(*body_thunk_sequence), hlo)) {}
-tensorflow::Status ForThunk::Initialize(const GpuExecutable& executable) {
- TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable));
- return tensorflow::Status::OK();
+Status ForThunk::Initialize(const GpuExecutable& executable,
+ se::StreamExecutor* executor) {
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+ return Status::OK();
}
-tensorflow::Status ForThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) {
for (int64 i = 0; i < loop_limit_; ++i) {
// Invoke loop body thunk sequence.
TF_RETURN_IF_ERROR(
body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h
index c78d1c5068..41ddfe0ceb 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h
@@ -36,9 +36,10 @@ class ForThunk : public Thunk {
ForThunk(const ForThunk&) = delete;
ForThunk& operator=(const ForThunk&) = delete;
- tensorflow::Status Initialize(const GpuExecutable& executable) override;
- tensorflow::Status ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) override;
+ Status Initialize(const GpuExecutable& executable,
+ se::StreamExecutor* executor) override;
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) override;
private:
const int64 loop_limit_;
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index f996fe486d..79fca43d02 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -215,6 +215,25 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
}
}
+DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) {
+ if (hlo_instruction.opcode() == HloOpcode::kDot) {
+ return hlo_instruction.dot_dimension_numbers();
+ }
+ CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion);
+ CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput);
+ CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(),
+ HloOpcode::kMultiply);
+ // Try to find the dot inside the output fusion node.
+ const HloInstruction* dot =
+ hlo_instruction.fused_expression_root()->operand(0);
+ if (dot->opcode() != HloOpcode::kDot) {
+ dot = hlo_instruction.fused_expression_root()->operand(1);
+ }
+ CHECK_EQ(dot->opcode(), HloOpcode::kDot);
+
+ return dot->dot_dimension_numbers();
+}
+
} // namespace
GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
@@ -232,8 +251,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
output_shape_(output_shape),
alpha_(alpha) {}
-tensorflow::Status GemmThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) {
VLOG(2) << "Executing a GemmThunk";
se::DeviceMemoryBase lhs_data =
@@ -281,8 +300,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream(
shape.dimensions(!is_row_major));
};
- const DotDimensionNumbers& dim_nums =
- hlo_instruction()->dot_dimension_numbers();
+ DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
const MatrixDescriptor lhs_descriptor = make_descriptor(
lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0);
@@ -350,7 +368,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream(
if (!launch_ok) {
return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
index f42cbf9e94..7a4830d64e 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
@@ -47,8 +47,8 @@ class GemmThunk : public Thunk {
GemmThunk& operator=(const GemmThunk&) = delete;
// Does the gemm operation for the thunk on "stream", which must be non-null.
- tensorflow::Status ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) override;
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) override;
// Returns true if we'll perform autotuning if run on the given stream. If
// so, we want the GPU to be quiescent during autotuning, so as not to
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 4fdc4c8961..df494a1aa9 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -128,9 +128,8 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) {
}
// Runs optimization passes on the given HLO module.
-tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
- se::StreamExecutor* stream_exec,
- DeviceMemoryAllocator* device_allocator) {
+Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
+ DeviceMemoryAllocator* device_allocator) {
{
HloPassPipeline pipeline("optimization");
pipeline.AddInvariantChecker<HloVerifier>();
@@ -283,12 +282,12 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
}
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
// Modifies the given HLO module so that it will be accepted by IrEmitter.
// Unlike optimization passes, the passes are necessary for correctness.
-tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
+Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
// In some cases, we have to place the result of an instruction in a temporary
// buffer. For instance, the buffer that holds an external parameter is
// assumed immutable at this point, and should not be reused for output
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 980cc89fa0..f8766474a8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -134,9 +134,10 @@ Status GpuExecutable::ExecuteThunks(
const BufferAllocations& buffer_allocations, bool block_host_until_done,
HloExecutionProfile* hlo_execution_profile) {
se::Stream* main_stream = run_options->stream();
+ se::StreamExecutor* executor = main_stream->parent();
std::pair<int, int> stream_compute_compatibility;
- main_stream->parent()->GetDeviceDescription().cuda_compute_capability(
+ executor->GetDeviceDescription().cuda_compute_capability(
&stream_compute_compatibility.first,
&stream_compute_compatibility.second);
TF_RET_CHECK(stream_compute_compatibility == compute_capability_)
@@ -155,21 +156,17 @@ Status GpuExecutable::ExecuteThunks(
sub_streams.reserve(thunk_schedule_->StreamCount() - 1);
while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) {
sub_streams.emplace_back();
- TF_ASSIGN_OR_RETURN(
- sub_streams.back(),
- run_options->BorrowStream(main_stream->parent()->device_ordinal()));
+ TF_ASSIGN_OR_RETURN(sub_streams.back(),
+ run_options->BorrowStream(executor->device_ordinal()));
}
HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream,
sub_streams, hlo_module_->entry_computation());
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
- // The next event enqueued on stream N must not run until the thunk at
- // last_blocking_thunk_for_stream[N] completes.
- std::map<int32, const Thunk*> last_blocking_thunk_for_stream;
std::map<const Thunk*, std::unique_ptr<se::Event>> thunk_to_finish_event;
for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
- TF_RETURN_IF_ERROR(thunk->Initialize(*this));
+ TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
int32 stream_no =
thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction());
se::Stream* stream =
@@ -179,18 +176,10 @@ Status GpuExecutable::ExecuteThunks(
stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get());
}
- if (last_blocking_thunk_for_stream.count(stream_no)) {
- stream->ThenWaitFor(FindOrDie(thunk_to_finish_event,
- last_blocking_thunk_for_stream[stream_no])
- .get());
- last_blocking_thunk_for_stream.erase(stream_no);
- }
-
// If this thunk requests it, wait for all currently-executing thunks to
// finish. This is useful e.g. if the thunk is about to perform autotuning.
if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) {
TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
- last_blocking_thunk_for_stream.clear();
}
profiler.StartOperation();
@@ -198,22 +187,11 @@ Status GpuExecutable::ExecuteThunks(
<< thunk->hlo_instruction()->ToString() << " on stream "
<< stream_no;
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
- if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) {
+ if (thunk_schedule_->Depended(thunk)) {
auto finish_event = MakeUnique<se::Event>(main_stream->parent());
finish_event->Init();
stream->ThenRecordEvent(finish_event.get());
thunk_to_finish_event[thunk] = std::move(finish_event);
-
- if (thunk->ShouldBlockFutureThunks()) {
- // Set last_blocking_thunk_for_stream on all streams other than this one
- // so that all other streams will wait for this thunk to complete before
- // executing any events that occur later in the total order.
- for (int32 i = 0; i < sub_streams.size() + 1; ++i) {
- if (i != stream_no) {
- last_blocking_thunk_for_stream[i] = thunk;
- }
- }
- }
}
profiler.FinishOperation(thunk->hlo_instruction());
}
@@ -286,8 +264,8 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
se::StreamExecutor* executor = run_options->stream()->parent();
TF_ASSIGN_OR_RETURN(
auto buffer_allocations,
- buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(),
- memory_allocator));
+ buffer_allocations_builder.Build(
+ assignment_.get(), executor->device_ordinal(), memory_allocator));
bool block_host_until_done =
!memory_allocator->AllowsAsynchronousDeallocation();
@@ -329,8 +307,7 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
buffers_in_result.insert(src_base);
return Status::OK();
}));
- TF_RETURN_IF_ERROR(
- buffer_allocations->TearDown(buffers_in_result, *assignment_));
+ TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
return std::move(shaped_buffer);
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index f13727ca9b..7bb8df6581 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -44,8 +44,8 @@ GpuTransferManager::GpuTransferManager()
/*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout)
.getPointerSize(0 /* default address space */)) {}
-Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) {
+Status GpuTransferManager::TransferLiteralToInfeed(
+ se::StreamExecutor* executor, const LiteralSlice& literal) {
const Shape& shape = literal.shape();
VLOG(2) << "Transferring literal to infeed with shape: "
<< ShapeUtil::HumanString(shape);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
index d040a99975..09f8227f50 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
@@ -37,7 +37,7 @@ class GpuTransferManager : public GenericTransferManager {
~GpuTransferManager() override {}
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) override;
+ const LiteralSlice& literal) override;
Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
const void* source) override;
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
index 6436abc06c..e230d538cc 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
@@ -42,6 +42,15 @@ class HloScheduleTest : public HloTestBase {
.ConsumeValueOrDie();
}
+ std::unique_ptr<HloModule> CreateNewModule() {
+ HloModuleConfig config;
+ auto debug_options = GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_disable_multi_streaming(false);
+ config.set_debug_options(debug_options);
+ return MakeUnique<HloModule>("test_module", VersionedComputationHandle(),
+ config);
+ }
+
HloVec RemoveHlo(const HloVec& input,
const std::unordered_set<const HloInstruction*>& remove) {
HloVec result(input);
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index c5eb721185..5d5bef6b57 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -17,7 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace gpu {
@@ -46,6 +48,15 @@ bool IsFusile(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kTranspose;
}
+bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) {
+ if (constant->opcode() != HloOpcode::kConstant ||
+ !ShapeUtil::IsScalar(constant->shape())) {
+ return false;
+ }
+ auto type = constant->shape().element_type();
+ return type == F16 || type == F32 || type == F64;
+}
+
} // namespace
/*static*/ bool GpuInstructionFusion::IsExpensive(
@@ -66,34 +77,71 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
HloInstruction* producer = consumer->mutable_operand(operand_index);
// Check if we can use output fusion for (A @ B) * alpha
- if (producer->opcode() == HloOpcode::kDot) {
- if (consumer->opcode() == HloOpcode::kMultiply) {
- CHECK_EQ(consumer->operand_count(), 2);
- int64 other_operand_index = 1 - operand_index;
- const HloInstruction* alpha = consumer->operand(other_operand_index);
- if (alpha->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsScalar(alpha->shape())) {
+ if (consumer->operand_count() == 2 &&
+ (producer->opcode() == HloOpcode::kDot ||
+ (producer->opcode() == HloOpcode::kFusion &&
+ producer->fused_expression_root()->opcode() == HloOpcode::kDot))) {
+ int64 other_operand_index = 1 - operand_index;
+ const HloInstruction* alpha = consumer->operand(other_operand_index);
+ HloInstruction* op1 = nullptr;
+ HloInstruction* op2 = nullptr;
+ if (consumer->opcode() == HloOpcode::kFusion &&
+ consumer->fusion_kind() == HloInstruction::FusionKind::kLoop &&
+ Match(consumer->fused_expression_root(),
+ match::Op()
+ .WithOpcode(HloOpcode::kMultiply)
+ .WithOperand(0, match::Op(&op1))
+ .WithOperand(1, match::Op(&op2)))) {
+ CHECK(op1 != nullptr && op2 != nullptr);
+ // If 'consumer' is a fusion node, it should consist of a broadcast of a
+ // scalar constant fused into a multiply, but nothing more. So one operand
+ // should be a parameter, and the other should be a broadcast.
+ if (op1->opcode() != HloOpcode::kParameter) {
+ std::swap(op1, op2);
+ }
+ if (op1->opcode() != HloOpcode::kParameter ||
+ op2->opcode() != HloOpcode::kBroadcast) {
+ return false;
+ }
+ if (IsIEEEFloatingPointScalarConstant(alpha)) {
+ return true;
+ }
+ } else if (consumer->opcode() == HloOpcode::kMultiply) {
+ // Fuse if 'alpha' is a broadcast of a scalar constant.
+ if (alpha->opcode() == HloOpcode::kBroadcast &&
+ alpha->dimensions().empty() &&
+ IsIEEEFloatingPointScalarConstant(alpha->operand(0))) {
return true;
}
}
}
- // Only allow to fuse transpose into an output fusion.
+ // Only allow fusing transpose or broadcast into an output fusion that is
+ // implemented as a Gemm call.
if (consumer->opcode() == HloOpcode::kFusion &&
- consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) {
- if (producer->opcode() != HloOpcode::kTranspose) {
- return false;
- }
- // Check that the transpose is the operand of a dot.
+ consumer->fusion_kind() == HloInstruction::FusionKind::kOutput &&
+ ImplementedAsGemm(*consumer)) {
auto producer_operand_index = consumer->operand_index(producer);
auto fused_parameter = consumer->fused_parameter(producer_operand_index);
const std::vector<HloInstruction*>& fused_parameter_users =
fused_parameter->users();
- return (fused_parameter_users.size() == 1 &&
- fused_parameter_users[0]->opcode() == HloOpcode::kDot);
+ if (fused_parameter_users.size() != 1) {
+ return false;
+ }
+ if (producer->opcode() == HloOpcode::kTranspose) {
+ // Check that the transpose is an operand of a dot.
+ return fused_parameter_users[0]->opcode() == HloOpcode::kDot;
+ }
+ if (producer->opcode() == HloOpcode::kBroadcast) {
+ // Check that the broadcast is a broadcast of a scalar constant into a
+ // multiply.
+ return producer->dimensions().empty() &&
+ IsIEEEFloatingPointScalarConstant(producer->operand(0)) &&
+ fused_parameter_users[0]->opcode() == HloOpcode::kMultiply;
+ }
}
- // Output fusion is not currently supported on GPUs.
+ // Other output fusions are not currently supported on GPUs.
if (producer->opcode() == HloOpcode::kFusion) {
return false;
}
@@ -134,7 +182,9 @@ HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
if (IsReductionToVector(*consumer)) {
return HloInstruction::FusionKind::kInput;
}
- if (producer->opcode() == HloOpcode::kDot) {
+ if (producer->opcode() == HloOpcode::kDot ||
+ (producer->opcode() == HloOpcode::kFusion &&
+ producer->fused_expression_root()->opcode() == HloOpcode::kDot)) {
return HloInstruction::FusionKind::kOutput;
}
if (HloOpcode::kFusion == consumer->opcode()) {
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 6c9a805ad6..760e0e90f5 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -108,8 +108,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
+ ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
@@ -125,8 +125,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
+ ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
@@ -232,12 +232,13 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
auto module = tools::Parse(R"(
HloModule test_module
ENTRY OutputFusion {
- constant = f32[] constant(3)
+ alpha = f32[] constant(3)
+ broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
p0 = f32[4,3]{1,0} parameter(0)
p1 = f32[4,3]{1,0} parameter(1)
transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0}
- dot = f32[4,4]{1,0} dot(p0, transpose)
- ROOT mul = f32[4,4] multiply(constant, dot)
+ dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT mul = f32[4,4] multiply(dot, broadcast)
})")
.ValueOrDie();
@@ -247,10 +248,11 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput);
EXPECT_THAT(
root->fused_expression_root(),
- op::Multiply(op::Parameter(),
- op::Dot(op::Parameter(), op::Transpose(op::Parameter()))));
+ op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())),
+ op::Broadcast(op::Parameter())));
}
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
@@ -309,5 +311,31 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
.ValueOrDie());
}
+TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY NoOutputFusion {
+ alpha = f32[] constant(3)
+ broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
+ p0 = f32[4,3]{1,0} parameter(0)
+ p1 = f32[3,4]{1,0} parameter(1)
+ dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ d = f32[4,4]{1,0} multiply(dot, dot)
+ ROOT mul = f32[4,4] multiply(d, broadcast)
+ })")
+ .ValueOrDie();
+
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
+ EXPECT_THAT(root->fused_expression_root(),
+ op::Multiply(op::Multiply(op::Parameter(), op::Parameter()),
+ op::Broadcast(op::Parameter())));
+}
+
} // namespace gpu
} // namespace xla
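
The two tests above encode the fusion rule this change targets: `multiply(dot, broadcast(alpha))` becomes an output fusion only when the multiply is the dot's sole consumer, while a dot that feeds anything else (here, the extra `multiply(dot, dot)`) leaves the pass with an ordinary loop fusion. A minimal stand-alone sketch of that predicate, using made-up `Node`/`user_count` types rather than XLA's `HloInstruction`, purely for illustration:

```cpp
#include <cstddef>
#include <vector>

// Illustrative node type, not an HloInstruction.
struct Node {
  enum Kind { kDot, kMultiply, kBroadcast, kOther } kind;
  std::vector<const Node*> operands;
  std::size_t user_count = 0;  // how many instructions consume this node
};

// Returns true when `root` looks like multiply(dot, ...) and the dot feeds
// nothing else, mirroring DotOutputFusion vs. DotOutputFusionImpossible.
bool CanOutputFuseDot(const Node& root) {
  if (root.kind != Node::kMultiply || root.operands.size() != 2) return false;
  const Node* dot = nullptr;
  for (const Node* op : root.operands) {
    if (op->kind == Node::kDot) dot = op;
  }
  return dot != nullptr && dot->user_count == 1;
}
```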
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 96199035b9..22e7150995 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -59,6 +59,25 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
!ShapeUtil::HasZeroElements(lhs_shape) &&
!ShapeUtil::HasZeroElements(rhs_shape);
}
+
+bool DotImplementedAsGemm(const HloInstruction& dot) {
+ CHECK_EQ(dot.opcode(), HloOpcode::kDot);
+ const Shape& lhs_shape = dot.operand(0)->shape();
+ const Shape& rhs_shape = dot.operand(1)->shape();
+
+ // If gemm can accept the operand shapes, use it rather than a custom
+ // kernel.
+ if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) {
+ // The size of the reduction dimension should match. The shape inference
+ // guarantees this invariant, so the check here is for programming
+ // errors.
+ const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
+ CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
+ rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
+ return true;
+ }
+ return false;
+}
} // namespace
bool ImplementedAsGemm(const HloInstruction& hlo) {
@@ -69,20 +88,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
// For certain types of Dot, we can call pre-canned BLAS gemm.
if (hlo.opcode() == HloOpcode::kDot) {
- const Shape& lhs_shape = hlo.operand(0)->shape();
- const Shape& rhs_shape = hlo.operand(1)->shape();
-
- // If gemm can accept the operand shapes, use it rather than a custom
- // kernel.
- if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
- // The size of the reduction dimension should match. The shape inference
- // guarantees this invariant, so the check here is for programming
- // errors.
- const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers();
- CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
- rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
- return true;
- }
+ return DotImplementedAsGemm(hlo);
}
if (hlo.opcode() == HloOpcode::kFusion &&
@@ -94,7 +100,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
dot = hlo.fused_expression_root()->operand(1);
}
if (dot->opcode() == HloOpcode::kDot) {
- return ImplementedAsGemm(*dot);
+ return DotImplementedAsGemm(*dot);
}
}
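
The refactoring above pulls the GEMM shape check into a local `DotImplementedAsGemm` helper so that the `kFusion` branch inspects the wrapped dot directly instead of recursing into `ImplementedAsGemm`, which would redo the opcode dispatch. A hedged sketch of the same pattern with placeholder types (the `Op` struct and `DotIsGemmCompatible` are invented for this example, not XLA's API):

```cpp
// Illustrative stand-ins; not XLA types.
struct Op {
  bool is_dot = false;
  bool is_fusion = false;
  const Op* fused_root = nullptr;  // the dot wrapped by the fusion, if any
  int lhs_rank = 2;
  int rhs_rank = 2;
};

namespace {
// The shared helper: the GEMM shape test lives in exactly one place.
bool DotIsGemmCompatible(const Op& dot) {
  // Stand-in for AreValidGemmShapes: cuBLAS GEMM wants rank-2 operands.
  return dot.lhs_rank == 2 && dot.rhs_rank == 2;
}
}  // namespace

bool ImplementedAsGemmLike(const Op& op) {
  if (op.is_dot) {
    return DotIsGemmCompatible(op);
  }
  // A fusion wrapping a dot checks the inner dot directly instead of
  // re-entering the top-level dispatch.
  if (op.is_fusion && op.fused_root != nullptr && op.fused_root->is_dot) {
    return DotIsGemmCompatible(*op.fused_root);
  }
  return false;
}
```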
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 83d90296df..0d7ba4cf9a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2194,6 +2194,21 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
/*destination_buffer=*/GetAllocationSlice(*inst), inst);
}
+namespace {
+double GetScalarConstantAsDouble(const Literal& literal) {
+ switch (literal.shape().element_type()) {
+ case F16:
+ return static_cast<double>(literal.Get<Eigen::half>({}));
+ case F32:
+ return literal.Get<float>({});
+ case F64:
+ return literal.Get<double>({});
+ default:
+ LOG(FATAL) << "Unsupported type.";
+ }
+}
+} // namespace
+
std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
const HloInstruction* inst) {
if (inst->opcode() == HloOpcode::kDot) {
@@ -2218,6 +2233,17 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
if (dot->opcode() != HloOpcode::kDot) {
std::swap(dot, alpha);
}
+ if (alpha->opcode() == HloOpcode::kBroadcast) {
+ alpha = alpha->operand(0);
+ }
+ alpha = inst->operand(alpha->parameter_number());
+ // TODO(b/74185543): Remove the following if block once we support fusion
+ // with a non-constant as well. Then we will just always use the constant
+ // on the device.
+ if (alpha->opcode() == HloOpcode::kCopy) {
+ alpha = alpha->operand(0);
+ }
+
DCHECK(dot->opcode() == HloOpcode::kDot);
const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
@@ -2229,13 +2255,13 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
inst->operand(rhs_parameter->parameter_number());
return MakeUnique<GemmThunk>(
- GetAllocationSlice(*lhs), // The buffer assigned to LHS.
- GetAllocationSlice(*rhs), // The buffer assigned to RHS.
- GetAllocationSlice(*mul), // The output buffer.
- lhs->shape(), // The shape of LHS.
- rhs->shape(), // The shape of RHS.
- inst->shape(), // The shape of the output.
- alpha->literal().Get<double>({0}), // alpha.
+ GetAllocationSlice(*lhs), // The buffer assigned to LHS.
+ GetAllocationSlice(*rhs), // The buffer assigned to RHS.
+ GetAllocationSlice(*inst), // The output buffer.
+ lhs->shape(), // The shape of LHS.
+ rhs->shape(), // The shape of RHS.
+ inst->shape(), // The shape of the output.
+ GetScalarConstantAsDouble(alpha->literal()), // alpha.
inst);
}
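
BuildGemmThunk now resolves the fused alpha constant by stepping through the broadcast (and, for the time being, a copy) that fusion inserts around the scalar, then widens it with GetScalarConstantAsDouble. A simplified sketch of that resolution step, assuming an invented `Node` type and omitting the parameter-to-fusion-operand hop (`inst->operand(alpha->parameter_number())`) that the real code performs:

```cpp
#include <vector>

// Illustrative node type, not an HloInstruction.
struct Node {
  enum class Op { kConstant, kBroadcast, kCopy, kParameter, kOther } op;
  std::vector<const Node*> operands;
  double constant_value = 0.0;  // meaningful only when op == kConstant
};

// Peel the wrappers the fusion placed around the scalar alpha until the
// underlying constant is reached; return nullptr if the chain does not end
// in a constant.
const Node* ResolveScalarConstant(const Node* alpha) {
  if (alpha->op == Node::Op::kBroadcast) alpha = alpha->operands[0];
  if (alpha->op == Node::Op::kCopy) alpha = alpha->operands[0];
  return alpha->op == Node::Op::kConstant ? alpha : nullptr;
}
```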
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index d376ef7a24..f56c1ce69f 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -35,26 +35,38 @@ KernelThunk::KernelThunk(
kernel_name_(kernel_name),
unroll_factor_(unroll_factor) {}
-tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) {
+Status KernelThunk::Initialize(const GpuExecutable& executable,
+ se::StreamExecutor* executor) {
tensorflow::mutex_lock lock(mutex_);
- if (loader_spec_) {
- // Already initialized by another thread.
- return tensorflow::Status::OK();
- }
+ if (!loader_spec_) {
+ loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size()));
+ tensorflow::StringPiece ptx = executable.ptx();
+ // Convert tensorflow::StringPiece to se::port::StringPiece because
+ // StreamExecutor uses the latter.
+ loader_spec_->AddCudaPtxInMemory(
+ se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_);
- loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size()));
- tensorflow::StringPiece ptx = executable.ptx();
- // Convert tensorflow::StringPiece to se::port::StringPiece because
- // StreamExecutor uses the latter.
- loader_spec_->AddCudaPtxInMemory(
- se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_);
+ if (!executable.cubin().empty()) {
+ loader_spec_->AddCudaCubinInMemory(
+ reinterpret_cast<const char*>(executable.cubin().data()),
+ kernel_name_);
+ }
+ }
- if (!executable.cubin().empty()) {
- loader_spec_->AddCudaCubinInMemory(
- reinterpret_cast<const char*>(executable.cubin().data()), kernel_name_);
+ // Load the kernel into the device if necessary.
+ //
+ // We could alternatively do this within ExecuteOnStream, but doing it here
+ // keeps the time spent loading the kernel from counting towards our
+ // execution profiles.
+ auto it = kernel_cache_.find(executor);
+ if (kernel_cache_.end() == it) {
+ it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first;
+ if (!executor->GetKernel(*loader_spec_, &it->second)) {
+ return InternalError("Unable to load kernel %s", kernel_name_.c_str());
+ }
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) {
@@ -62,21 +74,18 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) {
launch_dimensions_ = launch_dims;
}
-tensorflow::Status KernelThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) {
// Load the kernel.
se::StreamExecutor* executor = stream->parent();
LaunchDimensions launch_dimensions;
const se::KernelBase* kernel = nullptr;
+
{
tensorflow::mutex_lock lock(mutex_);
auto it = kernel_cache_.find(executor);
- if (kernel_cache_.end() == it) {
- it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first;
- if (!executor->GetKernel(*loader_spec_, &it->second)) {
- return InternalError("Unable to load kernel %s", kernel_name_.c_str());
- }
- }
+ CHECK(it != kernel_cache_.end())
+ << "Initialize() not called for StreamExecutor " << executor;
launch_dimensions = launch_dimensions_;
kernel = &it->second;
}
@@ -97,7 +106,7 @@ tensorflow::Status KernelThunk::ExecuteOnStream(
*kernel_args)) {
return InternalError("Unable to launch kernel %s", kernel_name_.c_str());
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
} // namespace gpu
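
The KernelThunk change moves GPU kernel loading out of ExecuteOnStream and into Initialize: Initialize builds the loader spec once, then loads and caches the kernel for each StreamExecutor it is handed, so ExecuteOnStream only has to assert the cache entry exists. A minimal standard-library sketch of the same initialize-then-execute cache; `Device`, `Kernel`, and `LoadKernel` are placeholders, not StreamExecutor APIs:

```cpp
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct Device {};                     // placeholder for se::StreamExecutor
struct Kernel { std::string name; };  // placeholder for se::KernelBase

// Placeholder loader; in the real code this is executor->GetKernel(...).
bool LoadKernel(Device* device, const std::string& name, Kernel* out) {
  out->name = name;
  return true;
}

class CachedKernelRunner {
 public:
  explicit CachedKernelRunner(std::string name) : name_(std::move(name)) {}

  // May be called many times, once per device; loading happens here so it is
  // not charged to the execution path.
  void Initialize(Device* device) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(device);
    if (it == cache_.end()) {
      it = cache_.emplace(device, Kernel{}).first;
      if (!LoadKernel(device, name_, &it->second)) {
        throw std::runtime_error("Unable to load kernel " + name_);
      }
    }
  }

  // Precondition: Initialize(device) has been called for this device.
  const Kernel& Execute(Device* device) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(device);
    if (it == cache_.end()) {
      throw std::logic_error("Initialize() not called for this device");
    }
    return it->second;
  }

 private:
  std::mutex mutex_;
  std::string name_;
  // std::unordered_map never invalidates references to stored values on
  // insertion, which is why the real code can hand out pointers into the
  // cache (the "pointer stability" noted in kernel_thunk.h below).
  std::unordered_map<Device*, Kernel> cache_;
};
```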
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
index b556befe66..7def27e189 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -57,11 +57,12 @@ class KernelThunk : public Thunk {
int unroll_factor() const { return unroll_factor_; }
void SetLaunchDimensions(const LaunchDimensions& launch_dims);
- tensorflow::Status Initialize(const GpuExecutable& executable) override;
+ Status Initialize(const GpuExecutable& executable,
+ se::StreamExecutor* executor) override;
// Executes the kernel for the thunk on "stream", which must be non-null.
- tensorflow::Status ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) override;
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) override;
private:
// Buffers passed to the kernel as arguments.
@@ -83,7 +84,8 @@ class KernelThunk : public Thunk {
mutable tensorflow::mutex mutex_;
std::unique_ptr<se::MultiKernelLoaderSpec> loader_spec_ GUARDED_BY(mutex_);
- // Loaded kernels for each `StreamExecutor`
+ // Loaded kernels for each `StreamExecutor`. Requires pointer stability of
+ // values.
std::unordered_map<se::StreamExecutor*, se::KernelBase> kernel_cache_
GUARDED_BY(mutex_);
};
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index d70cb07c57..917c576823 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -77,8 +77,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path,
// Since CUDA 9.0, all GPU versions are included in a single file
const char* unified_libdevice_filename = "libdevice.10.bc";
std::vector<string> unified_libdevice_files;
- const tensorflow::Status status =
- tensorflow::Env::Default()->GetMatchingPaths(
+ const Status status = tensorflow::Env::Default()->GetMatchingPaths(
tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename),
&unified_libdevice_files);
if (status.ok() && unified_libdevice_files.size() == 1) {
@@ -311,11 +310,11 @@ bool CouldNeedLibdevice(const llvm::Module& module) {
}
// Links libdevice into the given module if the module needs libdevice.
-tensorflow::Status LinkLibdeviceIfNecessary(
- llvm::Module* module, std::pair<int, int> compute_capability,
- const string& libdevice_dir_path) {
+Status LinkLibdeviceIfNecessary(llvm::Module* module,
+ std::pair<int, int> compute_capability,
+ const string& libdevice_dir_path) {
if (!CouldNeedLibdevice(*module)) {
- return tensorflow::Status::OK();
+ return Status::OK();
}
llvm::Linker linker(*module);
@@ -336,7 +335,7 @@ tensorflow::Status LinkLibdeviceIfNecessary(
return tensorflow::errors::Internal(tensorflow::strings::StrCat(
"Error linking libdevice from ", libdevice_path));
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
StatusOr<string> CompileModuleToPtx(llvm::Module* module,
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
index c8510808f1..b50f5b5a90 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
@@ -24,20 +24,20 @@ SequentialThunk::SequentialThunk(std::vector<std::unique_ptr<Thunk>>&& thunks,
const HloInstruction* hlo)
: Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {}
-tensorflow::Status SequentialThunk::Initialize(
- const GpuExecutable& executable) {
+Status SequentialThunk::Initialize(const GpuExecutable& executable,
+ se::StreamExecutor* executor) {
for (auto& thunk : thunks_) {
- TF_RETURN_IF_ERROR(thunk->Initialize(executable));
+ TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor));
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status SequentialThunk::ExecuteOnStream(
+Status SequentialThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream) {
for (const auto& thunk : thunks_) {
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
} // namespace gpu
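
SequentialThunk simply forwards the new two-argument Initialize to every sub-thunk and stops at the first failure; WhileThunk below does the same for its condition and body sequences. A self-contained sketch of that composite forwarding, with invented `ExecutableHandle`/`ExecutorHandle` types standing in for the real arguments:

```cpp
#include <memory>
#include <vector>

struct ExecutableHandle {};
struct ExecutorHandle {};

class InitializableStep {
 public:
  virtual ~InitializableStep() = default;
  virtual bool Initialize(const ExecutableHandle& exe,
                          ExecutorHandle* exec) = 0;
};

class SequentialSteps : public InitializableStep {
 public:
  explicit SequentialSteps(
      std::vector<std::unique_ptr<InitializableStep>> steps)
      : steps_(std::move(steps)) {}

  // Initialize children in order; the first failure is propagated, matching
  // the TF_RETURN_IF_ERROR loop in SequentialThunk::Initialize above.
  bool Initialize(const ExecutableHandle& exe, ExecutorHandle* exec) override {
    for (auto& step : steps_) {
      if (!step->Initialize(exe, exec)) return false;
    }
    return true;
  }

 private:
  std::vector<std::unique_ptr<InitializableStep>> steps_;
};
```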
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
index df17b8d67b..3537110bb5 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
@@ -38,9 +38,10 @@ class SequentialThunk : public Thunk {
const std::vector<std::unique_ptr<Thunk>>& thunks() const { return thunks_; }
- tensorflow::Status Initialize(const GpuExecutable& executable) override;
- tensorflow::Status ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) override;
+ Status Initialize(const GpuExecutable& executable,
+ se::StreamExecutor* executor) override;
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) override;
private:
// The list of sub-thunks.
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index b42767dfd5..696fa7e019 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -28,6 +28,15 @@ namespace gpu {
class StreamAssignmentTest : public HloTestBase {
protected:
+ std::unique_ptr<HloModule> CreateNewModule() {
+ HloModuleConfig config;
+ auto debug_options = GetDebugOptionsForTest();
+ debug_options.set_xla_gpu_disable_multi_streaming(false);
+ config.set_debug_options(debug_options);
+ return MakeUnique<HloModule>("test_module", VersionedComputationHandle(),
+ config);
+ }
+
// Pre-canned shapes.
Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
};
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index a0c785ed91..931c0bffab 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -70,11 +70,14 @@ class Thunk {
Kind kind() const { return kind_; }
const HloInstruction* hlo_instruction() const { return hlo_instruction_; }
- // Prepares for executing the thunk. This method is called only once over
- // Thunk's lifetime. For example, KernelThunk::Initialize loads the PTX of a
- // kernel, which is the same in every execution.
- virtual tensorflow::Status Initialize(const GpuExecutable& executable) {
- return tensorflow::Status::OK();
+ // Prepares the thunk for execution on the given StreamExecutor.
+ //
+ // This may be called multiple times. Its main purpose is to give us a chance
+ // to do initialization outside of ExecuteOnStream() so that the
+ // time spent initializing doesn't count towards our execution profile.
+ virtual Status Initialize(const GpuExecutable& /*executable*/,
+ se::StreamExecutor* /*executor*/) {
+ return Status::OK();
}
// Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream)
@@ -89,21 +92,13 @@ class Thunk {
return false;
}
- // Indicates whether thunks scheduled after this one should wait for this one
- // to complete before running. For example, a convolution thunk creates a
- // scratch allocator, then kicks off a convolution in cudnn via the stream
- // executor. When the stream executor call returns, the scratch allocator goes
- // out of scope, and the scratch memory is deallocated. In this case, the
- // convolution thunk needs to return true so that future thunks wait for the
- // convolution thunk to avoid reusing the deallocated memory until the
- // convolution thunk is done with it.
- virtual bool ShouldBlockFutureThunks() { return false; }
-
// Execute the kernel for the thunk on the given stream. This method must be
// called after Initialize and can be called multiple times over Thunk's
// lifetime. Stream argument must be non-null.
- virtual tensorflow::Status ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) = 0;
+ //
+ // Precondition: Initialize(stream->parent()) has been called.
+ virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) = 0;
private:
Kind kind_;
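
The Thunk interface change redefines the contract: Initialize(executable, executor) may run more than once (typically once per StreamExecutor) and must precede ExecuteOnStream on a stream owned by that executor. A hedged sketch of how a caller drives a thunk under the new contract; the types and the bool-for-Status simplification are assumptions of this example, not the real call sites:

```cpp
// Simplified stand-ins for the GPU backend types; names are illustrative.
struct Executable {};
struct StreamExecutor {};
struct Stream { StreamExecutor* parent; };
struct BufferAllocations {};

class ThunkLike {
 public:
  virtual ~ThunkLike() = default;

  // May be called multiple times, e.g. once per executor the executable will
  // run on. Heavy one-off work (kernel loading) belongs here so that it is
  // not attributed to execution time.
  virtual bool Initialize(const Executable&, StreamExecutor*) { return true; }

  // Precondition: Initialize(..., stream->parent) has been called.
  virtual bool ExecuteOnStream(const BufferAllocations&, Stream* stream) = 0;
};

// Typical driving sequence under the new contract.
bool RunOnce(ThunkLike& thunk, const Executable& exe,
             const BufferAllocations& buffers, Stream* stream) {
  if (!thunk.Initialize(exe, stream->parent)) return false;  // idempotent
  return thunk.ExecuteOnStream(buffers, stream);
}
```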
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
index ecb54857cc..97cb04c38f 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
@@ -20,8 +20,8 @@ limitations under the License.
namespace xla {
namespace gpu {
-tensorflow::Status TupleThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) {
std::vector<void*> tuple_element_buffer_addresses;
for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) {
tuple_element_buffer_addresses.push_back(
@@ -40,7 +40,7 @@ tensorflow::Status TupleThunk::ExecuteOnStream(
tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(),
sizeof(void*) * tuple_element_buffer_addresses.size());
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
index 8b459c29a1..951f809b51 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
@@ -45,8 +45,8 @@ class TupleThunk : public Thunk {
TupleThunk(const TupleThunk&) = delete;
TupleThunk& operator=(const TupleThunk&) = delete;
- tensorflow::Status ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) override;
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream) override;
private:
const std::vector<BufferAllocation::Slice> tuple_element_buffers_;
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index a9f3d619a3..30b9640c4c 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -34,9 +34,11 @@ WhileThunk::WhileThunk(
body_thunk_sequence_(
MakeUnique<SequentialThunk>(std::move(*body_thunk_sequence), hlo)) {}
-Status WhileThunk::Initialize(const GpuExecutable& executable) {
- TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable));
- TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable));
+Status WhileThunk::Initialize(const GpuExecutable& executable,
+ se::StreamExecutor* executor) {
+ TF_RETURN_IF_ERROR(
+ condition_thunk_sequence_->Initialize(executable, executor));
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h
index e589ca78a7..22176685a9 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h
@@ -45,7 +45,8 @@ class WhileThunk : public Thunk {
WhileThunk(const WhileThunk&) = delete;
WhileThunk& operator=(const WhileThunk&) = delete;
- Status Initialize(const GpuExecutable& executable) override;
+ Status Initialize(const GpuExecutable& executable,
+ se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream) override;
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
index e6caec8625..ad55728c45 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
@@ -144,7 +144,7 @@ class ExprTree {
TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first),
tagged_instructions));
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
private:
@@ -169,7 +169,7 @@ class MatcherBase {
// Attempts to match each ExprTree in 'expr_trees_'.
// Returns OK on the first successful match, error status otherwise.
- virtual tensorflow::Status Run() {
+ virtual Status Run() {
Status status;
for (const ExprTree& expr_tree : expr_trees_) {
status = MatchExprTree(expr_tree);
@@ -201,7 +201,7 @@ class MatcherBase {
} else if (type == S64) {
*const_value = literal.GetFirstElement<int64>();
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
StatusOr<const HloInstruction*> GetTaggedInstruction(
@@ -315,7 +315,7 @@ class WhileConditionComputationMatcher : public MatcherBase {
gte_fusion_param0->name().c_str());
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
const HloComputation* computation_;
@@ -379,7 +379,7 @@ class WhileInitOperandMatcher : public MatcherBase {
GetTaggedInstruction("loop_start", tagged_instructions));
TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_));
- return tensorflow::Status::OK();
+ return Status::OK();
}
const HloInstruction* while_hlo_;
@@ -477,7 +477,7 @@ class WhileBodyComputationMatcher : public MatcherBase {
}
}
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
const HloComputation* computation_;
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 3dd4c4a079..9a07ee3683 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -32,7 +32,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
const SequentialHloOrdering::HloModuleSequence& module_sequence,
const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_fn, const Options& options) {
+ const BufferValue::SizeFunction& size_fn, const Options& options) {
HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence);
const HloComputation* entry_computation = module.entry_computation();
const std::vector<const HloInstruction*>& instruction_sequence =
@@ -47,7 +47,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
const std::vector<const HloInstruction*>& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_fn, const Options& options) {
+ const BufferValue::SizeFunction& size_fn, const Options& options) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
/*module_sequence=*/nullptr);
TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
@@ -73,11 +73,11 @@ Status HeapSimulator::RunComputation(
// 'used_buffers' is the reverse map - it tracks which buffers were used by an
// instruction, so that we can remove the instructions from a buffer's live
// set after they are visited.
- FlatMap<const LogicalBuffer*, FlatSet<const HloInstruction*>> live_buffers;
- FlatMap<const HloInstruction*, FlatSet<const LogicalBuffer*>> used_buffers;
+ FlatMap<const BufferValue*, FlatSet<const HloInstruction*>> live_buffers;
+ FlatMap<const HloInstruction*, FlatSet<const BufferValue*>> used_buffers;
auto add_user_to_buffer = [this, &live_buffers, &used_buffers](
const HloInstruction* user,
- const LogicalBuffer* buffer) {
+ const BufferValue* buffer) {
if (!IgnoreBuffer(buffer)) {
VLOG(4) << " Adding user " << user->name() << " to buffer "
<< buffer->ToString();
@@ -96,7 +96,7 @@ Status HeapSimulator::RunComputation(
const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet();
for (const HloInstruction* user : instruction->users()) {
if (user->opcode() != HloOpcode::kGetTupleElement) {
- for (const LogicalBuffer* buffer : buffer_set) {
+ for (const BufferValue* buffer : buffer_set) {
add_user_to_buffer(user, buffer);
}
} else {
@@ -104,12 +104,12 @@ Status HeapSimulator::RunComputation(
// alive. It only needs the buffers that relate to the element its
// extracting, and the tuple it's extracting from, but not the buffers
// for the other elements.
- for (const LogicalBuffer* buffer : points_to.element({})) {
+ for (const BufferValue* buffer : points_to.element({})) {
add_user_to_buffer(user, buffer);
}
const PointsToSet& gte_points_to =
points_to_analysis.GetPointsToSet(user);
- for (const LogicalBuffer* buffer : gte_points_to.CreateFlattenedSet()) {
+ for (const BufferValue* buffer : gte_points_to.CreateFlattenedSet()) {
add_user_to_buffer(user, buffer);
}
}
@@ -117,24 +117,25 @@ Status HeapSimulator::RunComputation(
}
const HloInstruction* root = computation.root_instruction();
- auto output_source_buffers =
- points_to_analysis.GetPointsToSet(root).CreateFlattenedSet();
+ BufferValueCompactPointerSet output_source_buffers =
+ ToBufferValueCompactPointerSet(
+ points_to_analysis.GetPointsToSet(root).CreateFlattenedSet());
- std::vector<const LogicalBuffer*> dead_buffers_to_free;
- std::vector<const LogicalBuffer*> operand_buffers_to_free;
+ std::vector<const BufferValue*> dead_buffers_to_free;
+ std::vector<const BufferValue*> operand_buffers_to_free;
for (const HloInstruction* instruction : instruction_sequence) {
const TuplePointsToAnalysis::BufferDefinitionVector&
buffers_defined_by_instruction =
points_to_analysis.GetBuffersDefinedByInstruction(instruction);
VLOG(3) << "Instruction: " << instruction->ToString();
- for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
+ for (const BufferValue* buffer : buffers_defined_by_instruction) {
VLOG(4) << " Defines: " << buffer->ToString()
<< (IgnoreBuffer(buffer) ? " (Ignored)" : "");
}
dead_buffers_to_free.clear();
- for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
+ for (const BufferValue* buffer : buffers_defined_by_instruction) {
if (IgnoreBuffer(buffer)) {
continue;
}
@@ -161,7 +162,7 @@ Status HeapSimulator::RunComputation(
// have no instructions left to visit are moved from live_buffers to
// operand_buffers_to_free.
operand_buffers_to_free.clear();
- for (const LogicalBuffer* operand_buffer : used_buffers[instruction]) {
+ for (const BufferValue* operand_buffer : used_buffers[instruction]) {
if (IgnoreBuffer(operand_buffer)) {
continue;
}
@@ -177,7 +178,7 @@ Status HeapSimulator::RunComputation(
}
// Sort to get a deterministic iteration order.
std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(),
- [](const LogicalBuffer* x, const LogicalBuffer* y) {
+ [](const BufferValue* x, const BufferValue* y) {
return x->id() < y->id();
});
@@ -188,7 +189,7 @@ Status HeapSimulator::RunComputation(
//
// INVARIANT: Either Alloc or ShareBuffer will be called for each buffer
// that we should assign.
- for (const LogicalBuffer* buffer : buffers_defined_by_instruction) {
+ for (const BufferValue* buffer : buffers_defined_by_instruction) {
if (IgnoreBuffer(buffer)) {
continue;
}
@@ -199,7 +200,7 @@ Status HeapSimulator::RunComputation(
// we must be the last user of the buffer.
bool shared = false;
if (options_.may_reuse_operand_buffers) {
- for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) {
+ for (const BufferValue* operand_buffer : operand_buffers_to_free) {
if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
buffer->instruction()->opcode() != HloOpcode::kCopy &&
CanShareOperandBufferWithUser(
@@ -248,11 +249,11 @@ Status HeapSimulator::RunComputation(
// Free buffers that are no longer live. This is the earliest point that we
// can de-allocate; right after the last use of the buffer.
- for (const LogicalBuffer* buffer : dead_buffers_to_free) {
+ for (const BufferValue* buffer : dead_buffers_to_free) {
VLOG(3) << " Freeing dead: " << buffer->ToString();
Free(buffer, instruction);
}
- for (const LogicalBuffer* buffer : operand_buffers_to_free) {
+ for (const BufferValue* buffer : operand_buffers_to_free) {
VLOG(3) << " Freeing operand: " << buffer->ToString();
Free(buffer, instruction);
}
@@ -261,10 +262,10 @@ Status HeapSimulator::RunComputation(
// Any remaining live buffers must be entry parameters or output source
// buffers, which had a nullptr sentry added. Free them now, in a
// deterministic order.
- std::vector<const LogicalBuffer*> to_free;
+ std::vector<const BufferValue*> to_free;
to_free.reserve(live_buffers.size());
for (const auto& buffer_pending : live_buffers) {
- const LogicalBuffer* buffer = buffer_pending.first;
+ const BufferValue* buffer = buffer_pending.first;
const FlatSet<const HloInstruction*>& pending = buffer_pending.second;
CHECK_EQ(pending.size(), 1) << *buffer;
CHECK(*pending.begin() == nullptr) << *buffer;
@@ -272,10 +273,10 @@ Status HeapSimulator::RunComputation(
}
std::sort(to_free.begin(), to_free.end(),
- [](const LogicalBuffer* x, const LogicalBuffer* y) {
+ [](const BufferValue* x, const BufferValue* y) {
return x->id() < y->id();
});
- for (const LogicalBuffer* buffer : to_free) {
+ for (const BufferValue* buffer : to_free) {
VLOG(3) << "Freeing pending: " << buffer->ToString();
Free(buffer, root);
}
@@ -285,7 +286,7 @@ Status HeapSimulator::RunComputation(
HeapSimulator::HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
- const LogicalBuffer::SizeFunction& size_fn, const Options& options,
+ const BufferValue::SizeFunction& size_fn, const Options& options,
const SequentialHloOrdering::HloModuleSequence* module_sequence)
: no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
@@ -297,7 +298,7 @@ HeapSimulator::HeapSimulator(
HeapSimulator::~HeapSimulator() {}
-bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const {
+bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const {
// Buffers for constants are ignored unless the alloc_constants option is
// set. Also ignore buffers that we're not meant to assign.
//
@@ -311,7 +312,7 @@ bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const {
}
// Alloc always calls the underlying heap algorithm.
-void HeapSimulator::Alloc(const LogicalBuffer* buffer,
+void HeapSimulator::Alloc(const BufferValue* buffer,
const HloInstruction* instruction) {
CHECK(allocated_buffers_.count(buffer) == 0)
<< "Alloc called on allocated buffer: " << *buffer;
@@ -331,7 +332,7 @@ void HeapSimulator::Alloc(const LogicalBuffer* buffer,
// buffers whose group liveness has expired. Shared group liveness is tracked
// by maintaining a refcount; the Free call on the last buffer in the group
// causes Free to be called on the underlying algorithm.
-void HeapSimulator::Free(const LogicalBuffer* buffer,
+void HeapSimulator::Free(const BufferValue* buffer,
const HloInstruction* instruction) {
auto shared_it = shared_buffers_.find(buffer);
if (shared_it != shared_buffers_.end()) {
@@ -362,8 +363,8 @@ void HeapSimulator::Free(const LogicalBuffer* buffer,
// The 'buffer' must be a non-allocated, non-freed buffer, just like in calls to
// Alloc. The 'shared' buffer must be a previously allocated or shared buffer.
// Both 'buffer' and 'shared' will be associated with the same SharedGroup.
-void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer,
- const LogicalBuffer* shared,
+void HeapSimulator::ShareBuffer(const BufferValue* buffer,
+ const BufferValue* shared,
const HloInstruction* instruction) {
CHECK_LE(size_fn_(*buffer), size_fn_(*shared))
<< "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared;
@@ -374,7 +375,7 @@ void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer,
CHECK(freed_buffers_.count(shared) == 0)
<< "ShareBuffer called on freed shared buffer: " << *shared;
- const LogicalBuffer* canonical = nullptr;
+ const BufferValue* canonical = nullptr;
auto shared_it = shared_buffers_.find(shared);
if (shared_it != shared_buffers_.end()) {
// The 'shared' buffer already has a group; it might be the canonical, but
@@ -408,7 +409,7 @@ HeapSimulator::Result HeapSimulator::Finish() {
// collecting statistics, e.g. NoFragmentationStatsHeap.
if (!result.chunk_map.empty()) {
for (const auto& share_pair : shared_buffers_) {
- const LogicalBuffer* buffer = share_pair.first;
+ const BufferValue* buffer = share_pair.first;
std::shared_ptr<SharedGroup> group = share_pair.second;
if (buffer != group->canonical) {
// The canonical must already exist in the chunk_map, since we called
@@ -437,9 +438,9 @@ HeapSimulator::Result HeapSimulator::Finish() {
}
void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
- const LogicalBuffer* buffer,
+ const BufferValue* buffer,
const HloInstruction* instruction,
- const LogicalBuffer* share_with_canonical) {
+ const BufferValue* share_with_canonical) {
HeapSimulatorTrace::Event* event = debug_trace_.add_events();
event->set_kind(kind);
event->set_buffer_id(buffer->id());
@@ -453,14 +454,14 @@ void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
}
}
-void NoFragmentationStatsHeap::Alloc(const LogicalBuffer* buffer, int64 size) {
+void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
current_heap_size_ += size;
if (current_heap_size_ > max_heap_size_) {
max_heap_size_ = current_heap_size_;
}
}
-void NoFragmentationStatsHeap::Free(const LogicalBuffer* buffer, int64 size) {
+void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) {
current_heap_size_ -= size;
}
@@ -472,12 +473,12 @@ HeapSimulator::Result NoFragmentationStatsHeap::Finish() {
return result;
}
-void DecreasingSizeRunsHeap::Alloc(const LogicalBuffer* buffer, int64 size) {
+void DecreasingSizeRunsHeap::Alloc(const BufferValue* buffer, int64 size) {
SetMode(kAlloc);
run_.emplace_back(Op{buffer, size});
}
-void DecreasingSizeRunsHeap::Free(const LogicalBuffer* buffer, int64 size) {
+void DecreasingSizeRunsHeap::Free(const BufferValue* buffer, int64 size) {
CHECK(mode_ != kInit) << "Free called on empty heap: " << *buffer;
SetMode(kFree);
run_.emplace_back(Op{buffer, size});
@@ -518,7 +519,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() {
run_.clear();
}
-void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) {
+void LazyBestFitHeap::Alloc(const BufferValue* buffer, int64 size) {
// Degenerate case: 0-sized buffers are always allocated at offset 0.
if (size == 0) {
result_.chunk_map.emplace(buffer, Chunk{0, 0});
@@ -586,7 +587,7 @@ void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) {
result_.chunk_map.emplace(buffer, Chunk{kLazyAllocOffset, size});
}
-void LazyBestFitHeap::Free(const LogicalBuffer* buffer, int64 size) {
+void LazyBestFitHeap::Free(const BufferValue* buffer, int64 size) {
auto alloc_it = result_.chunk_map.find(buffer);
CHECK(alloc_it != result_.chunk_map.end())
<< "Free called on non-allocated buffer: " << *buffer;
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index 636f19dd39..8b2b43a37a 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -21,11 +21,12 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -43,7 +44,7 @@ class HeapAlgorithm;
// don't need to return the assignment of buffer offsets until the very end.
class HeapSimulator {
public:
- // Chunk represents a contiguous piece of memory. Each LogicalBuffer will be
+ // Chunk represents a contiguous piece of memory. Each BufferValue will be
// associated with a chunk in the assignment result.
struct Chunk {
int64 offset;
@@ -55,7 +56,7 @@ class HeapSimulator {
// Result represents the result of the heap simulation.
struct Result {
// The assignment of buffers to chunks.
- tensorflow::gtl::FlatMap<const LogicalBuffer*, Chunk> chunk_map;
+ tensorflow::gtl::FlatMap<const BufferValue*, Chunk> chunk_map;
// The total size in bytes of the heap, containing all assigned chunks.
int64 heap_size = 0;
@@ -81,7 +82,7 @@ class HeapSimulator {
bool alloc_constants;
// If 'buffers_to_assign' is provided, only those buffers are assigned
// offsets, otherwise all buffers defined by the instructions are assigned.
- const tensorflow::gtl::FlatSet<const LogicalBuffer*>* buffers_to_assign;
+ const BufferValueFlatSet* buffers_to_assign;
};
// Run the heap simulation with the given algorithm, assuming the given
@@ -97,7 +98,7 @@ class HeapSimulator {
std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
const SequentialHloOrdering::HloModuleSequence& module_sequence,
const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_fn,
+ const BufferValue::SizeFunction& size_fn,
const Options& options = Options());
// Same as above, but runs on a single computation. The 'instruction_sequence'
@@ -109,7 +110,7 @@ class HeapSimulator {
const HloComputation& computation,
const std::vector<const HloInstruction*>& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_fn,
+ const BufferValue::SizeFunction& size_fn,
const Options& options = Options());
private:
@@ -118,7 +119,7 @@ class HeapSimulator {
// be run recursively. I.e. the simulation is run over the whole module.
HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
- const LogicalBuffer::SizeFunction& size_fn, const Options& options,
+ const BufferValue::SizeFunction& size_fn, const Options& options,
const SequentialHloOrdering::HloModuleSequence* module_sequence);
~HeapSimulator();
@@ -127,21 +128,21 @@ class HeapSimulator {
const std::vector<const HloInstruction*>& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis);
- bool IgnoreBuffer(const LogicalBuffer* buffer) const;
- void Alloc(const LogicalBuffer* buffer, const HloInstruction* instruction);
- void Free(const LogicalBuffer* buffer, const HloInstruction* instruction);
- void ShareBuffer(const LogicalBuffer* buffer, const LogicalBuffer* shared,
+ bool IgnoreBuffer(const BufferValue* buffer) const;
+ void Alloc(const BufferValue* buffer, const HloInstruction* instruction);
+ void Free(const BufferValue* buffer, const HloInstruction* instruction);
+ void ShareBuffer(const BufferValue* buffer, const BufferValue* shared,
const HloInstruction* instruction);
Result Finish();
void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
- const LogicalBuffer* buffer,
+ const BufferValue* buffer,
const HloInstruction* instruction,
- const LogicalBuffer* shared_with_canonical);
+ const BufferValue* shared_with_canonical);
const std::unique_ptr<HeapAlgorithm> no_fragmentation_stats_;
const std::unique_ptr<HeapAlgorithm> algorithm_;
- const LogicalBuffer::SizeFunction size_fn_;
+ const BufferValue::SizeFunction size_fn_;
const Options options_;
const SequentialHloOrdering::HloModuleSequence* module_sequence_;
@@ -160,15 +161,15 @@ class HeapSimulator {
// The shared_buffers_ map associates each shared buffer (including the
// canonical) to its SharedGroup control block.
struct SharedGroup {
- const LogicalBuffer* canonical = nullptr;
+ const BufferValue* canonical = nullptr;
int64 refcount = 0;
};
- tensorflow::gtl::FlatMap<const LogicalBuffer*, std::shared_ptr<SharedGroup>>
+ tensorflow::gtl::FlatMap<const BufferValue*, std::shared_ptr<SharedGroup>>
shared_buffers_;
// Hold some sets for error-checking the sequence of Alloc and Free calls.
- tensorflow::gtl::FlatSet<const LogicalBuffer*> allocated_buffers_;
- tensorflow::gtl::FlatSet<const LogicalBuffer*> freed_buffers_;
+ tensorflow::gtl::FlatSet<const BufferValue*> allocated_buffers_;
+ tensorflow::gtl::FlatSet<const BufferValue*> freed_buffers_;
// Debugging information filled in while the heap simulator runs.
HeapSimulatorTrace debug_trace_;
@@ -186,10 +187,10 @@ class HeapAlgorithm {
virtual ~HeapAlgorithm() = default;
// Alloc allocates a buffer of 'size' bytes.
- virtual void Alloc(const LogicalBuffer* buffer, int64 size) = 0;
+ virtual void Alloc(const BufferValue* buffer, int64 size) = 0;
// Free de-allocates a previously allocated buffer.
- virtual void Free(const LogicalBuffer* buffer, int64 size) = 0;
+ virtual void Free(const BufferValue* buffer, int64 size) = 0;
// Finish collects the buffer offset assignment results. Finish may only be
// called once, after all Alloc and Free calls.
@@ -205,8 +206,8 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
NoFragmentationStatsHeap() = default;
~NoFragmentationStatsHeap() override = default;
- void Alloc(const LogicalBuffer* buffer, int64 size) override;
- void Free(const LogicalBuffer* buffer, int64 size) override;
+ void Alloc(const BufferValue* buffer, int64 size) override;
+ void Free(const BufferValue* buffer, int64 size) override;
Result Finish() override;
private:
@@ -223,14 +224,14 @@ class DecreasingSizeRunsHeap : public HeapAlgorithm {
: algorithm_(std::move(algorithm)) {}
~DecreasingSizeRunsHeap() override {}
- void Alloc(const LogicalBuffer* buffer, int64 size) override;
- void Free(const LogicalBuffer* buffer, int64 size) override;
+ void Alloc(const BufferValue* buffer, int64 size) override;
+ void Free(const BufferValue* buffer, int64 size) override;
Result Finish() override;
private:
// A single Alloc or Free operation that we've buffered in run_.
struct Op {
- const LogicalBuffer* buffer;
+ const BufferValue* buffer;
int64 size;
};
@@ -266,8 +267,8 @@ class LazyBestFitHeap : public HeapAlgorithm {
LazyBestFitHeap(int64 alignment) : alignment_(alignment) {}
~LazyBestFitHeap() override {}
- void Alloc(const LogicalBuffer* buffer, int64 size) override;
- void Free(const LogicalBuffer* buffer, int64 size) override;
+ void Alloc(const BufferValue* buffer, int64 size) override;
+ void Free(const BufferValue* buffer, int64 size) override;
Result Finish() override;
private:
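
HeapAlgorithm is now expressed in terms of BufferValue, but the shape of the interface is unchanged: Alloc and Free per buffer, plus a Finish that reports the result. A minimal conforming implementation in the spirit of NoFragmentationStatsHeap, written against a placeholder buffer type rather than the real header:

```cpp
#include <algorithm>
#include <cstdint>

// Placeholder for xla::BufferValue; only identity matters to this algorithm.
struct BufferValueLike { std::int64_t id; };

// Tracks the peak of the running "everything currently live" heap size, the
// same statistic NoFragmentationStatsHeap reports.
class PeakSizeHeap {
 public:
  void Alloc(const BufferValueLike* /*buffer*/, std::int64_t size) {
    current_ += size;
    peak_ = std::max(peak_, current_);
  }

  void Free(const BufferValueLike* /*buffer*/, std::int64_t size) {
    current_ -= size;
  }

  // Called once, after all Alloc and Free calls.
  std::int64_t Finish() const { return peak_; }

 private:
  std::int64_t current_ = 0;
  std::int64_t peak_ = 0;
};
```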
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index fd56a603bb..6271652412 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -39,7 +39,7 @@ const char kFree[] = "Free";
const char kFinish[] = "Finish";
// CallSequence records a sequence of Alloc/Free/Finish calls.
-using CallSequence = std::vector<std::pair<string, const LogicalBuffer*>>;
+using CallSequence = std::vector<std::pair<string, const BufferValue*>>;
// HeapCallRecorder is a dummy heap algorithm that simply records its calls.
class HeapCallRecorder : public HeapAlgorithm {
@@ -47,7 +47,7 @@ class HeapCallRecorder : public HeapAlgorithm {
explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
~HeapCallRecorder() override {}
- void Alloc(const LogicalBuffer* buffer, int64 size) override {
+ void Alloc(const BufferValue* buffer, int64 size) override {
calls_->emplace_back(kAlloc, buffer);
// Instead of assigning a real offset, we set the cardinality of the Alloc
// call. This isn't a valid assignment, but allows us to easily test for
@@ -55,7 +55,7 @@ class HeapCallRecorder : public HeapAlgorithm {
const int64 offset = result_.chunk_map.size();
result_.chunk_map.emplace(buffer, Chunk{offset, size});
}
- void Free(const LogicalBuffer* buffer, int64 size) override {
+ void Free(const BufferValue* buffer, int64 size) override {
calls_->emplace_back(kFree, buffer);
}
Result Finish() override {
@@ -118,7 +118,7 @@ class HeapSimulatorTracker {
// Hack the size_fn so that it returns a decreasing value as we step through
// the sequence. This lets us ensure the Alloc calls are in the sequence
- // order. The Free calls are sorted by LogicalBuffer.id, which is at least
+ // order. The Free calls are sorted by BufferValue.id, which is at least
// deterministic.
auto size_fn = [&reverse_position](const BufferValue& buffer) {
return reverse_position[buffer.instruction()];
@@ -133,8 +133,8 @@ class HeapSimulatorTracker {
HloModule* module() { return module_.get(); }
// Returns the buffer defined at the given instruction and index.
- const LogicalBuffer* BufferAt(const HloInstruction* instruction,
- const ShapeIndex& index) const {
+ const BufferValue* BufferAt(const HloInstruction* instruction,
+ const ShapeIndex& index) const {
return points_to_analysis_->GetBufferDefinedAt(instruction, index)
.ConsumeValueOrDie();
}
@@ -150,8 +150,8 @@ class HeapSimulatorTracker {
const ShapeIndex& index_a,
const HloInstruction* instruction_b,
const ShapeIndex& index_b) {
- const LogicalBuffer* a = BufferAt(instruction_a, index_a);
- const LogicalBuffer* b = BufferAt(instruction_b, index_b);
+ const BufferValue* a = BufferAt(instruction_a, index_a);
+ const BufferValue* b = BufferAt(instruction_b, index_b);
EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset)
<< *a << ", " << *b;
}
@@ -525,7 +525,7 @@ TEST_F(HeapSimulatorTest, WholeModule) {
// Now the final cond less-than buffer is allocated.
{kAlloc, tracker.BufferAt(cond_lt, {})},
- // The order of the remaining Free calls is based on the LogicalBuffer.id,
+ // The order of the remaining Free calls is based on the BufferValue.id,
// which is deterministic, but not obvious.
{kFree, tracker.BufferAt(param, {})},
{kFree, tracker.BufferAt(param, {0})},
@@ -547,40 +547,40 @@ TEST_F(HeapSimulatorTest, WholeModule) {
class HeapAlgorithmTestBase : public ::testing::Test {
protected:
HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
- buffer_a_ = DummyLogicalBuffer();
- buffer_b_ = DummyLogicalBuffer();
- buffer_c_ = DummyLogicalBuffer();
- buffer_d_ = DummyLogicalBuffer();
- buffer_e_ = DummyLogicalBuffer();
- buffer_f_ = DummyLogicalBuffer();
- buffer_g_ = DummyLogicalBuffer();
- buffer_h_ = DummyLogicalBuffer();
- buffer_i_ = DummyLogicalBuffer();
+ buffer_a_ = DummyBufferValue();
+ buffer_b_ = DummyBufferValue();
+ buffer_c_ = DummyBufferValue();
+ buffer_d_ = DummyBufferValue();
+ buffer_e_ = DummyBufferValue();
+ buffer_f_ = DummyBufferValue();
+ buffer_g_ = DummyBufferValue();
+ buffer_h_ = DummyBufferValue();
+ buffer_i_ = DummyBufferValue();
}
~HeapAlgorithmTestBase() override {}
- const LogicalBuffer* buffer_a_;
- const LogicalBuffer* buffer_b_;
- const LogicalBuffer* buffer_c_;
- const LogicalBuffer* buffer_d_;
- const LogicalBuffer* buffer_e_;
- const LogicalBuffer* buffer_f_;
- const LogicalBuffer* buffer_g_;
- const LogicalBuffer* buffer_h_;
- const LogicalBuffer* buffer_i_;
+ const BufferValue* buffer_a_;
+ const BufferValue* buffer_b_;
+ const BufferValue* buffer_c_;
+ const BufferValue* buffer_d_;
+ const BufferValue* buffer_e_;
+ const BufferValue* buffer_f_;
+ const BufferValue* buffer_g_;
+ const BufferValue* buffer_h_;
+ const BufferValue* buffer_i_;
private:
- // Create a dummy LogicalBuffer to pass to the heap algorithm.
- const LogicalBuffer* DummyLogicalBuffer() {
- const LogicalBuffer::Id id = buffers_.size();
+ // Create a dummy BufferValue to pass to the heap algorithm.
+ const BufferValue* DummyBufferValue() {
+ const BufferValue::Id id = buffers_.size();
auto const0 = builder_.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- buffers_.emplace_back(MakeUnique<LogicalBuffer>(const0, ShapeIndex{}, id));
+ buffers_.emplace_back(MakeUnique<HloValue>(id, const0, ShapeIndex{}));
return buffers_.back().get();
}
HloComputation::Builder builder_;
- std::vector<std::unique_ptr<LogicalBuffer>> buffers_;
+ std::vector<std::unique_ptr<BufferValue>> buffers_;
};
class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 17e43c3cb8..63c3dc4a59 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -365,25 +365,38 @@ std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
string HloComputation::ToString(const HloPrintOptions& options) const {
std::ostringstream s;
for (int i = 0; i < options.indent_amount(); i++) {
- s << " ";
+ s << " ";
}
- if (options.print_percent()) {
- s << "%";
+
+ if (!options.is_in_nested_computation()) {
+ if (options.print_percent()) {
+ s << "%";
+ }
+ s << name() << " ";
}
- s << name();
+
if (options.print_program_shape()) {
- s << " " << ShapeUtil::HumanString(ComputeProgramShape());
- }
- s << " {\n";
- for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
- for (int i = 0; i < options.indent_amount(); i++) {
- s << " ";
+ s << ShapeUtil::HumanString(ComputeProgramShape()) << " ";
+ }
+ s << "{\n";
+ {
+ // Print the instructions in this computation.
+ HloPrintOptions new_options = options;
+ new_options.set_indent_amount(options.indent_amount() + 1)
+ .set_is_in_nested_computation(true);
+ CanonicalNameMap name_map;
+ for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
+ for (int i = 0; i < new_options.indent_amount(); i++) {
+ s << " ";
+ }
+ s << (instruction == root_instruction_ ? "ROOT " : "")
+ << instruction->ToStringWithCanonicalNameMap(new_options, &name_map)
+ << "\n";
}
- s << " " << (instruction == root_instruction_ ? "ROOT " : "")
- << instruction->ToString(options) << "\n";
}
+
for (int i = 0; i < options.indent_amount(); i++) {
- s << " ";
+ s << " ";
}
s << "}";
return s.str();
@@ -407,27 +420,37 @@ HloComputationProto HloComputation::ToProto() const {
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
- HloModule* module, const HloComputationProto& proto,
+ const HloComputationProto& proto,
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
- std::vector<std::unique_ptr<HloInstruction>> instructions;
tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
+ tensorflow::gtl::FlatMap<HloInstruction*, int64> to_proto_id;
+ std::vector<std::unique_ptr<HloInstruction>> instructions;
int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloInstruction> instruction,
- HloInstruction::CreateFromProto(module, instruction_proto,
- instruction_map, computation_map));
+ HloInstruction::CreateFromProto(instruction_proto, instruction_map,
+ computation_map));
if (instruction->opcode() == HloOpcode::kParameter) {
parameter_count++;
}
TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
instruction_map[instruction_proto.id()] = instruction.get();
+ to_proto_id[instruction.get()] = instruction_proto.id();
instructions.push_back(std::move(instruction));
}
TF_RET_CHECK(proto.root_id() != -1);
TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
HloInstruction* root = instruction_map.at(proto.root_id());
+
+ // Sort the instructions in the proto id's order.
+ std::sort(instructions.begin(), instructions.end(),
+ [&](const std::unique_ptr<HloInstruction>& a,
+ const std::unique_ptr<HloInstruction>& b) {
+ return to_proto_id[a.get()] < to_proto_id[b.get()];
+ });
+
return WrapUnique(new HloComputation(proto.name(), parameter_count,
&instructions, root,
/*fusion_instruction=*/nullptr));
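
Two things change in HloComputation here: ToString now indents nested computations and threads a CanonicalNameMap through the instructions, and CreateFromProto records each instruction's proto id so a final sort restores the serialized order regardless of construction order. The ordering trick generalizes; a small sketch with an illustrative `Instr` type:

```cpp
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for an instruction being rebuilt from a proto.
struct Instr { std::string name; };

// Items may be materialized in any order, but each remembers the id it had
// in the serialized proto; sorting by that id restores the original order.
std::vector<std::unique_ptr<Instr>> SortByProtoId(
    std::vector<std::unique_ptr<Instr>> instrs,
    const std::unordered_map<const Instr*, std::int64_t>& to_proto_id) {
  std::sort(instrs.begin(), instrs.end(),
            [&](const std::unique_ptr<Instr>& a,
                const std::unique_ptr<Instr>& b) {
              return to_proto_id.at(a.get()) < to_proto_id.at(b.get());
            });
  return instrs;
}
```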
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 9898355625..ba9d44a9ab 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -157,14 +157,12 @@ class HloComputation {
// Creates a computation from the given proto. Arguments:
//
- // module: the module which will contain the computation. The newly created
- // computation is *not* added to the module, however.
// proto: the proto to convert from.
// computation_map: a map from computation id to HloComputation*. This map
// must contain all computations which the newly constructed computation
// calls.
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
- HloModule* module, const HloComputationProto& proto,
+ const HloComputationProto& proto,
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
// Gets the instructions in this computation.
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index 7b7588f4ba..25469a54c4 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -550,6 +550,108 @@ TEST_F(HloComputationTest, Reachability) {
EXPECT_FALSE(reachability->IsReachable(constant2, copy));
}
+TEST_F(HloComputationTest, Stringification) {
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+ const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+ const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+ const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+ HloComputation::Builder builder("TransposeDot");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ builder.AddInstruction(
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
+
+ auto options = HloPrintOptions().set_print_metadata(false);
+ EXPECT_EQ(computation->ToString(options),
+ R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ %x = f32[5,10]{1,0} parameter(0)
+ %y = f32[20,10]{1,0} parameter(1)
+ %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
+ ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})");
+}
+
+TEST_F(HloComputationTest, StringificationIndent) {
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+ const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+ const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+ const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+ HloComputation::Builder builder("TransposeDot");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ builder.AddInstruction(
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
+
+ auto options =
+ HloPrintOptions().set_print_metadata(false).set_indent_amount(2);
+ EXPECT_EQ(computation->ToString(options),
+ R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ %x = f32[5,10]{1,0} parameter(0)
+ %y = f32[20,10]{1,0} parameter(1)
+ %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
+ ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ })");
+}
+
+TEST_F(HloComputationTest, StringificationCanonical) {
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+ const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+ const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+ const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+ HloComputation::Builder builder("TransposeDot");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ builder.AddInstruction(
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
+
+ auto options = HloPrintOptions().set_print_metadata(false);
+ EXPECT_EQ(computation->ToString(options),
+ R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ %x = f32[5,10]{1,0} parameter(0)
+ %y = f32[20,10]{1,0} parameter(1)
+ %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
+ ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})");
+
+ options = HloPrintOptions().Canonical();
+ EXPECT_EQ(computation->ToString(options), R"(TransposeDot {
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 7b552ee5b1..5d05ccfc0b 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
const int64 slice_limits[] = {10, 8, 6, 5, 9};
const int64 slice_strides[] = {1, 1, 1, 1, 1};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- LiteralTestUtil::CreateRandomLiteral<F32>(
+ Literal::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
HloComputation::Builder builder(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- LiteralTestUtil::CreateRandomLiteral<F32>(
+ Literal::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->Literal::CloneToUnique();
HloInstruction* literal_instruction = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 44e4f75f75..94c9c7eabc 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -142,19 +142,25 @@ Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
}
Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
+ current_should_compute_bottleneck_time_ = false;
current_properties_[kBytesAccessedKey] = 0;
+ current_properties_[kOptimalSecondsKey] = 0;
return Status::OK();
}
Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
+ current_should_compute_bottleneck_time_ = false;
current_properties_[kBytesAccessedKey] = 0;
+ current_properties_[kOptimalSecondsKey] = 0;
return Status::OK();
}
Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) {
// GetTupleElement forwards a pointer and does not touch each element in the
// output.
+ current_should_compute_bottleneck_time_ = false;
current_properties_[kBytesAccessedKey] = 0;
+ current_properties_[kOptimalSecondsKey] = 0;
return Status::OK();
}
@@ -329,6 +335,7 @@ Status HloCostAnalysis::HandleSelectAndScatter(
Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
// A bitcast does no computation and touches no memory.
current_properties_[kBytesAccessedKey] = 0;
+ current_properties_[kOptimalSecondsKey] = 0;
return Status::OK();
}
@@ -555,11 +562,13 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
}
Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) {
- // We can't do anything sane with CustomCalls, since we don't know what they
- // do, and returning an error status will stop iteration over this
- // computation, which is probably also not what we want. So just punt and
- // return OK. This will cause all of the properties to be reported as 0,
- // which is fine.
+ // Mark applicable fields as "unknown", since we don't know what CustomCall
+ // does. This is better than returning an error, which would stop iteration,
+ // and therefore would prevent us from getting *any* stats for a computation
+ // which contains a CustomCall.
+ current_properties_[kOptimalSecondsKey] = -1;
+ current_properties_[kBytesAccessedKey] = -1;
+ current_properties_[kFlopsKey] = -1;
current_should_compute_bottleneck_time_ = false;
return Status::OK();
}
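Note on the -1 sentinel above: consumers of HloCostAnalysis properties now have to distinguish "unknown" (-1) from a genuine zero cost. A minimal standalone sketch of that guard, assuming nothing about the real accessor API beyond the sentinel convention introduced here (SumKnownCosts is a hypothetical helper, not part of HloCostAnalysis):

#include <vector>

// Sums per-instruction cost values while skipping the -1 "unknown" sentinel
// that the patched HandleCustomCall writes. Illustration only.
double SumKnownCosts(const std::vector<double>& per_instruction_costs) {
  double total = 0.0;
  for (double cost : per_instruction_costs) {
    if (cost < 0.0) continue;  // -1 means "unknown", not "free".
    total += cost;
  }
  return total;
}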
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index ed3b654851..0fb65c845a 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -162,6 +162,17 @@ StatusOr<HloInstruction*> MakeConcatHlo(ArraySlice<HloInstruction*> operands,
HloInstruction::CreateConcatenate(concat_shape, operands, dimension));
}
+StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dim_numbers) {
+ HloComputation* computation = lhs->parent();
+ CHECK_EQ(computation, rhs->parent());
+ TF_ASSIGN_OR_RETURN(
+ Shape dot_shape,
+ ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers));
+ return computation->AddInstruction(
+ HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
+}
+
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
CHECK_GT(n, 0);
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index c9a7361a6a..49b1402d68 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -97,6 +97,11 @@ StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
StatusOr<HloInstruction*> MakeConcatHlo(
tensorflow::gtl::ArraySlice<HloInstruction*> operands, int64 dimension);
+// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
+// and `rhs` (both must be in the same computation).
+StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dim_numbers);
+
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing
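A hedged usage sketch for the new MakeDotHlo helper, mirroring the rank-2 dot built in the tests earlier in this diff. MakeMatMul is a hypothetical wrapper, and both operands are assumed to be rank-2 instructions living in the same computation, per the contract stated above:

// Contracts lhs dimension 1 with rhs dimension 0 (a plain matrix multiply)
// and adds the resulting dot to lhs->parent().
StatusOr<HloInstruction*> MakeMatMul(HloInstruction* lhs, HloInstruction* rhs) {
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  return MakeDotHlo(lhs, rhs, dnums);
}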
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index 3b22c93733..28f861aecc 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace xla {
@@ -88,6 +89,20 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) {
return changed;
}
+// An instruction is considered to be equivalent to another only if they
+// share the exact same set of operands.
+int64 CseHash(const HloInstruction* instruction) {
+ int64 hash = std::hash<int64>()(static_cast<int64>(instruction->opcode()));
+ hash = tensorflow::Hash64Combine(
+ hash, instruction->opcode() == HloOpcode::kGetTupleElement
+ ? instruction->tuple_index()
+ : -1);
+ for (auto operand : instruction->operands()) {
+ hash = tensorflow::Hash64Combine(hash, operand->unique_id());
+ }
+ return hash;
+}
+
} // namespace
StatusOr<bool> HloCSE::Run(HloModule* module) {
@@ -96,6 +111,12 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
eq_instructions = std::equal_to<const HloInstruction*>();
const std::function<bool(const HloComputation*, const HloComputation*)>
eq_computations = std::equal_to<const HloComputation*>();
+
+ auto cse_equal = [&](const HloInstruction* lhs, const HloInstruction* rhs) {
+ return lhs->Identical(*rhs, eq_instructions, eq_computations,
+ is_layout_sensitive_);
+ };
+
for (auto* computation : module->computations()) {
if (only_fusion_computations_ && !computation->IsFusionComputation()) {
continue;
@@ -103,13 +124,17 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
changed |= CombineConstants(computation, is_layout_sensitive_);
- std::list<HloInstruction*> post_order =
- computation->MakeInstructionPostOrder();
- std::set<HloInstruction*> removed_instructions;
- for (auto instruction : post_order) {
- // If the instruction has already been removed by CSE skip over it.
- if (removed_instructions.count(instruction) > 0 ||
- instruction->operand_count() == 0) {
+ // HLO instructions are grouped into equivalency classes by using the
+ // cse_equal predicate defined above. This set holds a representative
+ // instruction for each class.
+ tensorflow::gtl::FlatSet<HloInstruction*, decltype(&CseHash),
+ decltype(cse_equal)>
+ representatives(/*N=*/1024, &CseHash, cse_equal);
+
+ for (auto instruction : computation->MakeInstructionPostOrder()) {
+ // If the instruction has zero operands (constants, parameters, etc.) skip
+ // over it.
+ if (instruction->operand_count() == 0) {
continue;
}
@@ -118,31 +143,16 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
continue;
}
- // An instruction is considered to be equivalent to another only if they
- // share the exact same set of operands. So to find equivalent
- // instructions, we just search among instructions which share operand(0)
- // of this instruction.
- const HloInstruction* operand = instruction->operand(0);
-
- tensorflow::gtl::InlinedVector<HloInstruction*, 8>
- equivalent_instructions;
- for (HloInstruction* user : operand->users()) {
- if (user != instruction && !user->HasSideEffect() &&
- user->Identical(*instruction, eq_instructions, eq_computations,
- is_layout_sensitive_)) {
- equivalent_instructions.push_back(user);
- }
- }
-
- // Replace all equivalent instructions with this instruction.
- for (HloInstruction* equivalent_instruction : equivalent_instructions) {
+ auto it = representatives.find(instruction);
+ if (it != representatives.end()) {
+ HloInstruction* equivalent_instruction = *it;
TF_RETURN_IF_ERROR(
- equivalent_instruction->ReplaceAllUsesWith(instruction));
- TF_RETURN_IF_ERROR(
- computation->RemoveInstruction(equivalent_instruction));
- removed_instructions.insert(equivalent_instruction);
+ instruction->ReplaceAllUsesWith(equivalent_instruction));
+ TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
changed = true;
+ continue;
}
+ representatives.insert(instruction);
}
}
return changed;
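The rewrite above replaces the per-operand user scan with a single hash set keyed by CseHash and the cse_equal predicate: the first instruction seen in each equivalence class becomes the representative, and later duplicates are rewired to it in one post-order pass. A self-contained toy model of that structure, using plain C++ containers and simplified hashing rather than the XLA types:

#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_set>
#include <vector>

// Toy expression: an opcode plus the ids of already-deduplicated operands.
struct Expr {
  std::string op;
  std::vector<int64_t> operand_ids;
  bool operator==(const Expr& other) const {
    return op == other.op && operand_ids == other.operand_ids;
  }
};

// Analogue of CseHash: combine the opcode with each operand id.
struct ExprHash {
  size_t operator()(const Expr& e) const {
    size_t h = std::hash<std::string>()(e.op);
    for (int64_t id : e.operand_ids) {
      h = h * 1000003u ^ std::hash<int64_t>()(id);
    }
    return h;
  }
};

// One pass over a post-order list: the first occurrence of each class is kept
// as the representative; later occurrences count as removable duplicates.
int RunToyCse(const std::vector<Expr>& post_order) {
  std::unordered_set<Expr, ExprHash> representatives;
  int removed = 0;
  for (const Expr& e : post_order) {
    if (!representatives.insert(e).second) {
      ++removed;  // an equivalent expression already represents this class
    }
  }
  return removed;
}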
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index df8853f34f..a04b4f4dcf 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -72,7 +72,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
auto result = ExecuteAndTransfer(std::move(module), {});
auto expected = Literal::CreateR0<float>(84.0);
- LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
@@ -104,7 +104,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
auto result = ExecuteAndTransfer(std::move(module), {});
auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
@@ -134,7 +134,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
auto result = ExecuteAndTransfer(std::move(module), {});
auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index e7425c8ba7..ff7d07ee16 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -52,12 +52,11 @@ namespace xla {
namespace {
using tensorflow::gtl::ArraySlice;
-using tensorflow::gtl::FlatSet;
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
- const Literal& lhs_literal,
- const Literal& rhs_literal) {
+ LiteralSlice lhs_literal,
+ LiteralSlice rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
@@ -106,8 +105,8 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
template <>
StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
- const Shape& shape, HloOpcode opcode, const Literal& lhs_literal,
- const Literal& rhs_literal) {
+ const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
+ LiteralSlice rhs_literal) {
std::function<bool(complex64, complex64)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index cc16446778..ae5b5e0412 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -82,9 +82,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
auto element_type = expected->shape().element_type();
if (element_type == F32 || element_type == F64) {
ErrorSpec error(aabs);
- LiteralTestUtil::ExpectNear(*expected, *result, error);
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error));
} else {
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
}
@@ -100,7 +100,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
std::unique_ptr<Literal> result = Evaluate();
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
bool use_bfloat16_;
@@ -129,7 +129,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
auto expected = Literal::CreateR2<float>({{0, 4}, {2, 4}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
@@ -150,7 +150,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
auto expected = Literal::CreateR2<float>({{0, 0}, {1, 1}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs select
@@ -175,7 +175,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
auto expected = Literal::CreateR2<float>({{2, 5}, {0, 4}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
@@ -307,7 +307,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
// Verifies Reshape operation is correctly evaluated.
@@ -315,7 +315,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
HloComputation::Builder b(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- LiteralTestUtil::CreateRandomLiteral<F32>(
+ Literal::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->CloneToUnique();
HloInstruction* literal_instruction =
@@ -351,7 +351,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) {
std::unique_ptr<Literal> result = Evaluate({});
- LiteralTestUtil::ExpectEqual(*result, *output_literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
}
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
@@ -370,7 +370,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
std::unique_ptr<Literal> result = Evaluate({});
- LiteralTestUtil::ExpectEqual(*result, *output_literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
}
TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
@@ -392,7 +392,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
auto expected =
Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
@@ -413,7 +413,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
std::unique_ptr<Literal> result = Evaluate();
auto expected = Literal::CreateR1<int64>({100, 200});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
@@ -432,7 +432,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
std::unique_ptr<Literal> result = Evaluate();
- LiteralTestUtil::ExpectEqual(*result, *expected);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
@@ -452,7 +452,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
std::unique_ptr<Literal> result = Evaluate();
- LiteralTestUtil::ExpectEqual(*result, *expected);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
PaddingConfig CreatePaddingConfig(
@@ -490,7 +490,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
auto expected = Literal::CreateR2<int32>(
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
@@ -525,7 +525,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
auto expected = Literal::CreateR4FromArray4D<float>(*expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, NegativePadding2D) {
@@ -567,7 +567,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
(*expected_array)(0, 4) = 2.718f;
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
- LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0x1.0P-5)));
}
TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
@@ -606,7 +606,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
auto expected_array = MakeUnique<Array2D<float>>(0, 9);
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
@@ -651,7 +651,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// clang-format on
auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
@@ -688,7 +688,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
auto expected = Literal::CreateR1<float>({22.f, 28.f});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
@@ -737,7 +737,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
});
auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, SimpleConv1D) {
@@ -785,7 +785,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
auto expected = Literal::CreateR3FromArray3D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
@@ -847,7 +847,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
// clang-format on
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
@@ -927,7 +927,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
auto expected = Literal::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
@@ -1004,7 +1004,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
auto expected = Literal::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
@@ -1067,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
}));
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
@@ -1131,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
}));
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest,
@@ -1203,7 +1203,7 @@ TEST_P(HloEvaluatorTest,
}));
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
@@ -1319,7 +1319,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
auto expected = Literal::CreateR1<float>({6, 18});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ReduceWindowMax) {
@@ -1370,7 +1370,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
std::unique_ptr<Literal> result = Evaluate();
auto expected = Literal::CreateR2<float>({{6, 7}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
@@ -1427,7 +1427,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
std::unique_ptr<Literal> result = Evaluate();
auto expected = Literal::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
@@ -1490,7 +1490,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
std::unique_ptr<Literal> result_literal =
Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
- LiteralTestUtil::ExpectEqual(*result_literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
}
TEST_P(HloEvaluatorTest, StridedSlice) {
@@ -1523,7 +1523,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
{19},
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DynamicSlice) {
@@ -1556,7 +1556,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
{6, 7, 8},
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
// Verifies that the HloEvaluator's implementation goes along with existing
@@ -1591,7 +1591,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
{6, 7, 8},
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
@@ -1627,7 +1627,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
{5, -6, -7},
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, SetAndGetTuples) {
@@ -1662,7 +1662,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
{5, 6, 7},
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
@@ -1703,7 +1703,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
result_inner_literal.get(),
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, Reverse) {
@@ -1756,7 +1756,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
});
// clang-format on
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
@@ -1776,8 +1776,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
add, {{param0, Literal::CreateR1<float>({1, 2, 3, 4}).get()},
{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
TF_ASSERT_OK(result.status());
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<float>({11, 22, 33, 44}),
- *result.ValueOrDie());
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
}
// Check that EvaluateWithSubstitutions works if one of the operands to the op
@@ -1800,8 +1800,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
auto result = evaluator.EvaluateWithSubstitutions(
add, {{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
TF_ASSERT_OK(result.status());
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<float>({11, 22, 33, 44}),
- *result.ValueOrDie());
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
@@ -1823,9 +1823,9 @@ ENTRY main {
std::unique_ptr<Literal> operand =
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1847,9 +1847,9 @@ ENTRY main {
std::unique_ptr<Literal> operand =
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@@ -1872,10 +1872,10 @@ ENTRY main {
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{0, 2}, {2, 1}});
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@@ -1900,9 +1900,9 @@ ENTRY main {
{{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{0, 0}, {1, 0}});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest,
@@ -1928,9 +1928,9 @@ ENTRY main {
{{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{0, 0}, {1, 0}});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@@ -1952,9 +1952,9 @@ ENTRY main {
std::unique_ptr<Literal> operand =
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR2<int32>({{5}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{5}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@@ -1977,9 +1977,9 @@ ENTRY main {
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{2, 1}, {1, 1}});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR3<int32>({{{8}}, {{5}}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{8}}, {{5}}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@@ -2000,9 +2000,34 @@ ENTRY main {
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR2<int32>({{}, {}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{}, {}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
+ const string hlo_text = R"(
+HloModule GatherXd
+
+ENTRY main {
+ operand = s32[3] parameter(0)
+ indices = s32[2,2,1] parameter(1)
+ ROOT gather = s32[2,2] gather(operand, indices),
+ output_window_dims={},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=2,
+ window_bounds={1}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+
+ std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({0, 1, 2});
+ std::unique_ptr<Literal> gather_indices =
+ Literal::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{0, 1}, {2, 1}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
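For reference, the new EvaluateGather_NoOutputWindowDims test exercises a gather whose every output element is a scalar: with output_window_dims={} the result is simply result[i][j] = operand[indices[i][j][0]]. A standalone sketch reproducing the test's data outside XLA:

#include <array>
#include <cassert>

int main() {
  // Same data as the test above: operand s32[3], indices s32[2,2,1].
  std::array<int, 3> operand = {0, 1, 2};
  int indices[2][2][1] = {{{0}, {1}}, {{2}, {1}}};
  int result[2][2];
  // With no output window dims, each output element is gathered directly
  // from the operand at the indexed position.
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      result[i][j] = operand[indices[i][j][0]];
    }
  }
  assert(result[0][0] == 0 && result[0][1] == 1);
  assert(result[1][0] == 2 && result[1][1] == 1);  // matches {{0,1},{2,1}}
  return 0;
}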
// Verifies that HloEvaluator evaluates a HLO instruction that performs
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index f1cb363478..0e4ef08ad3 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -256,6 +256,29 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleExpm1(HloInstruction* expm1) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[expm1],
+ ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) {
+ return std::expm1(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleExpm1(HloInstruction* floor) {
+ return InvalidArgument("Unsupported type for Expm1");
+ }
+
+ Status HandleExpm1(HloInstruction* floor) override {
+ return HandleExpm1<ReturnT>(floor);
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleFloor(HloInstruction* floor) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[floor],
@@ -284,6 +307,29 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ template <
+ typename NativeT,
+ typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleLog1p(HloInstruction* expm1) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[expm1],
+ ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) {
+ return std::log1p(elem_operand);
+ }));
+ return Status::OK();
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleLog1p(HloInstruction* floor) {
+ return InvalidArgument("Unsupported type for Log1p");
+ }
+
+ Status HandleLog1p(HloInstruction* floor) override {
+ return HandleLog1p<ReturnT>(floor);
+ }
+
template <typename NativeT,
typename std::enable_if<
std::is_integral<NativeT>::value &&
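The new HandleExpm1/HandleLog1p handlers call std::expm1 and std::log1p rather than composing exp and log, which avoids catastrophic cancellation for inputs near zero. A small standalone check of the difference (exact digits depend on the platform's double arithmetic):

#include <cmath>
#include <cstdio>

int main() {
  const double x = 1e-12;
  // The naive forms lose most of their significant digits at this magnitude;
  // expm1/log1p stay accurate to within a few ulps of 1e-12.
  std::printf("exp(x)-1 = %.17g\n", std::exp(x) - 1.0);
  std::printf("expm1(x) = %.17g\n", std::expm1(x));
  std::printf("log(1+x) = %.17g\n", std::log(1.0 + x));
  std::printf("log1p(x) = %.17g\n", std::log1p(x));
  return 0;
}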
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
index a0cb28246d..dcc4583165 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
@@ -16,34 +16,16 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
-class HloExecutionProfileTest : public HloTestBase {
- protected:
- static constexpr int64 kInstructionCyclesIndex = 0;
- static constexpr int64 kInstructionNameIndex = 19;
-};
+using tensorflow::strings::StrCat;
+using ::testing::AllOf;
+using ::testing::ContainsRegex;
-// Splits `lines` into a sequence of lines delimited by newlines and then split
-// each of those lines into a sequence of words delimited by spaces. Filter out
-// empty words.
-std::vector<std::vector<string>> SplitIntoLinesAndWords(
- tensorflow::StringPiece lines) {
- std::vector<std::vector<string>> result;
- for (const string& line : tensorflow::str_util::Split(lines, '\n')) {
- std::vector<string> words;
- for (const string& word : tensorflow::str_util::Split(line, ' ')) {
- if (!word.empty()) {
- words.push_back(word);
- }
- }
- result.push_back(std::move(words));
- }
-
- return result;
-}
+class HloExecutionProfileTest : public HloTestBase {};
TEST_F(HloExecutionProfileTest, Basic) {
std::unique_ptr<HloModule> hlo_module = CreateNewModule();
@@ -84,20 +66,12 @@ TEST_F(HloExecutionProfileTest, Basic) {
execution_profile.SetCyclesTakenBy(add_instruction, add_cycles);
execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles);
- string rendered_profile = execution_profile.ToString(
- backend().default_stream_executor()->GetDeviceDescription());
- std::vector<std::vector<string>> lines_and_words =
- SplitIntoLinesAndWords(rendered_profile);
- ASSERT_EQ(lines_and_words.size(), 8);
-
- const std::vector<string>& line_2 = lines_and_words[2];
- const std::vector<string>& line_3 = lines_and_words[3];
-
- EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles));
- EXPECT_EQ(line_2[kInstructionNameIndex], '%' + dot_instruction->name());
-
- EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles));
- EXPECT_EQ(line_3[kInstructionNameIndex], '%' + add_instruction->name());
+ EXPECT_THAT(execution_profile.ToString(
+ backend().default_stream_executor()->GetDeviceDescription()),
+ AllOf(ContainsRegex(StrCat(dot_cycles, R"(\b.*%)",
+ dot_instruction->name())),
+ ContainsRegex(StrCat(add_cycles, R"(\b.*%)",
+ add_instruction->name()))));
}
} // namespace
} // namespace xla
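The rewritten profile test drops the brittle column-index parsing and instead asserts, via ContainsRegex, that a cycle count and the %-prefixed instruction name appear on the same span of the dump (the patch adds a \b word boundary after the count so a shorter number does not match inside a longer one). A minimal standalone sketch of the same matcher style on a fabricated line; the dump text shown is invented for illustration:

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <string>

using ::testing::AllOf;
using ::testing::ContainsRegex;

TEST(ProfileDumpMatch, CyclesThenName) {
  const std::string dump =
      "40 cycles (66.6%) :: 1.20 usec :: ... :: %dot.3 = f32[5,20] dot(...)";
  EXPECT_THAT(dump, AllOf(ContainsRegex("40 cycles.*%dot"),
                          ContainsRegex("%dot\\.3")));
}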
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index b6b0387672..17e3c405f1 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -825,7 +825,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
*elem_count *= dim;
}
}
- if (elem_count.has_value() && *elem_count <= 8) {
+ if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
return Printf("%s (%s)", constant->literal().ToString(),
ShapeUtil::HumanString(constant->shape()));
}
@@ -925,6 +925,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kDivide:
case HloOpcode::kEq:
case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kGe:
case HloOpcode::kGt:
@@ -932,6 +933,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
@@ -1102,7 +1104,8 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
// Get the instruction's extra attributes excluding the names of its
// subcomputations, since those are drawn explicitly in the graph.
for (const auto& line : instr->ExtraAttributesToString(
- HloPrintOptions().set_print_subcomputation_references(false))) {
+ HloPrintOptions().set_print_subcomputation_mode(
+ HloPrintOptions::PrintSubcomputationMode::kOff))) {
lines.push_back(HtmlLikeStringSanitize(line));
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 857cd39adb..31aff008a4 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -51,7 +51,7 @@ using ::tensorflow::strings::StrCat;
/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
- HloModule* module, const HloInstructionProto& proto,
+ const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
TF_RET_CHECK(!proto.opcode().empty());
@@ -257,10 +257,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kCos:
case HloOpcode::kClz:
case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kReal:
@@ -1245,10 +1247,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kReal:
@@ -1557,6 +1561,8 @@ const Literal& HloInstruction::literal() const {
return *literal_;
}
+bool HloInstruction::HasLiteral() const { return literal_ != nullptr; }
+
bool HloInstruction::CanHaveDimensionsField() const {
return (opcode() == HloOpcode::kReverse ||
opcode() == HloOpcode::kConcatenate ||
@@ -1697,6 +1703,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kDivide:
case HloOpcode::kEq:
case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kGe:
case HloOpcode::kGt:
@@ -1704,6 +1711,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
case HloOpcode::kAnd:
case HloOpcode::kNot:
case HloOpcode::kOr:
@@ -2098,13 +2106,40 @@ string PrintName(const string& name, const HloPrintOptions& options) {
} // namespace
string HloInstruction::ToString(const HloPrintOptions& options) const {
- string result =
- StrCat(PrintName(name(), options), " = ",
- ShapeUtil::HumanStringWithLayout(shape()), " ",
- HloOpcodeString(opcode()), "(", OperandsToString(options), ")");
+ CanonicalNameMap new_map;
+ return ToStringWithCanonicalNameMap(options, &new_map);
+}
+
+string HloInstruction::ToStringWithCanonicalNameMap(
+ const HloPrintOptions& options,
+ CanonicalNameMap* canonical_name_map) const {
+ string result = "";
+
+ // Logic to print the instruction name (e.g. "%foo = ").
+ if (options.canonicalize_instruction_names()) {
+ if (options.is_in_nested_computation()) {
+ // If we are canonicalizing instruction names and this is a top-level
+ // HloInstruction::ToString() call, don't print an instruction name.
+ StrAppend(&result,
+ PrintName(canonical_name_map->LookupOrInsert(name()), options),
+ " = ");
+ }
+ } else {
+ StrAppend(&result, PrintName(name(), options), " = ");
+ }
+
+ // Print opcode, operand(s) and shape.
+ StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ",
+ HloOpcodeString(opcode()), "(",
+ OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
+ ")");
+
+ // Print additional attributes. If an instruction contains a subcomputation,
+ // the subcomputation is also printed here.
for (const string& extra : ExtraAttributesToString(options)) {
StrAppend(&result, ", ", extra);
}
+
if (options.print_metadata() &&
(!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty())) {
@@ -2117,6 +2152,13 @@ string HloInstruction::ToString(const HloPrintOptions& options) const {
}
string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
+ CanonicalNameMap new_map;
+ return OperandsToStringWithCanonicalNameMap(options, &new_map);
+}
+
+string HloInstruction::OperandsToStringWithCanonicalNameMap(
+ const HloPrintOptions& options,
+ CanonicalNameMap* canonical_name_map) const {
string operands;
if (opcode() == HloOpcode::kConstant) {
// For constants, show the actual value in place of an empty operand list.
@@ -2156,7 +2198,14 @@ string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
if (options.print_operand_shape()) {
str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
}
- if (!options.compact_operands()) {
+
+ // In a top-level HloInstruction::ToString() call, the operand name is not
+ // part of the canonical string.
+ if (options.canonicalize_instruction_names() &&
+ options.is_in_nested_computation()) {
+ str.push_back(PrintName(
+ canonical_name_map->LookupOrInsert(operand->name()), options));
+ } else if (!options.compact_operands()) {
str.push_back(PrintName(operand->name(), options));
}
StrAppend(out, Join(str, " "));
@@ -2225,7 +2274,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}"));
}
- if (options.print_subcomputation_references()) {
+ if (options.print_subcomputation_mode() ==
+ HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
if (opcode() == HloOpcode::kWhile) {
extra.push_back(
StrCat("condition=", PrintName(while_condition()->name(), options)));
@@ -2253,8 +2303,45 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
PrintName(computation->name(), options));
})));
}
+ } else if (options.print_subcomputation_mode() ==
+ HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
+ HloPrintOptions new_options = options;
+ new_options.set_is_in_nested_computation(true);
+ switch (opcode()) {
+ case HloOpcode::kWhile:
+ extra.push_back(
+ StrCat("condition=\n", while_condition()->ToString(new_options)));
+ extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
+ break;
+ case HloOpcode::kSelectAndScatter:
+ extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
+ extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
+ break;
+ case HloOpcode::kConditional:
+ extra.push_back(StrCat("true_computation=\n",
+ true_computation()->ToString(new_options)));
+ extra.push_back(StrCat("false_computation=\n",
+ false_computation()->ToString(new_options)));
+ break;
+ case HloOpcode::kCall:
+ case HloOpcode::kMap:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kReduce:
+ extra.push_back(
+ StrCat("to_apply=\n", to_apply()->ToString(new_options)));
+ break;
+ default:
+ if (!called_computations().empty()) {
+ extra.push_back(
+ StrCat("calls=\n",
+ Join(called_computations(), ", ",
+ [&](string* out, const HloComputation* computation) {
+ StrAppend(out, computation->ToString(new_options));
+ })));
+ }
+ break;
+ }
}
-
if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv ||
opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
extra.push_back(StrCat("channel_id=", channel_id_));
@@ -2292,7 +2379,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
}
// By contract, we print the custom call target even if
- // !options.print_subcomputation_references(), because the call target is not
+ // options.print_subcomputation_mode() == kOff, because the call target is not
// an HloComputation.
if (opcode() == HloOpcode::kCustomCall) {
extra.push_back(
@@ -2394,6 +2481,10 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.add_fft_length(fft_len);
}
+ if (has_sharding()) {
+ *proto.mutable_sharding() = sharding().ToProto();
+ }
+
proto.set_channel_name(channel_name_);
proto.set_cost_estimate_ns(cost_estimate_ns_);
@@ -2614,6 +2705,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleNegate(this);
case HloOpcode::kExp:
return visitor->HandleExp(this);
+ case HloOpcode::kExpm1:
+ return visitor->HandleExpm1(this);
case HloOpcode::kFloor:
return visitor->HandleFloor(this);
case HloOpcode::kCeil:
@@ -2622,6 +2715,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleClz(this);
case HloOpcode::kLog:
return visitor->HandleLog(this);
+ case HloOpcode::kLog1p:
+ return visitor->HandleLog1p(this);
case HloOpcode::kTanh:
return visitor->HandleTanh(this);
case HloOpcode::kCos:
@@ -2968,10 +3063,12 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kReal:
@@ -3036,7 +3133,7 @@ bool HloInstruction::IsElementwise() const {
bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const {
CHECK(IsElementwise());
- return !ShapeUtil::Equal(shape(), operand(operand_idx)->shape());
+ return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape());
}
namespace {
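The last hunk above loosens ImplicitlyBroadcastsOperand from ShapeUtil::Equal to ShapeUtil::SameDimensions, so an operand that differs from the result only in layout (or element type) is no longer reported as implicitly broadcast. A hedged sketch of the distinction, assuming the usual ShapeUtil semantics (Equal compares layouts as well, SameDimensions compares only dimension sizes):

// Two 2x3 f32 shapes that differ only in their minor-to-major layout.
const Shape row_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
const Shape col_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
CHECK(ShapeUtil::SameDimensions(row_major, col_major));  // dimensions agree
CHECK(!ShapeUtil::Equal(row_major, col_major));          // layouts do not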
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 14be58d069..0089cae51a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -60,23 +60,31 @@ class HloModule;
// A bunch of switches that control how the hlo text should be printed.
class HloPrintOptions {
public:
+ enum class PrintSubcomputationMode {
+ kOff, // Do not print anything about subcomputations.
+ kNameOnly, // Only print the name of subcomputations.
+ kFullBodies, // Print the full bodies of subcomputations.
+ };
+
// Constructs the default print options: don't print large constants, don't
// compact operands, no indentation.
HloPrintOptions()
: print_large_constants_(false),
- print_subcomputation_references_(true),
+ print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
print_metadata_(true),
print_backend_config_(true),
compact_operands_(false),
print_operand_shape_(true),
print_program_shape_(true),
print_percent_(true),
- indent_amount_(0) {}
+ canonicalize_instruction_names_(false),
+ indent_amount_(0),
+ is_in_nested_computation_(false) {}
static HloPrintOptions ShortParsable() {
return HloPrintOptions()
.set_print_large_constants(true)
- .set_print_subcomputation_references(true)
+ .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
.set_print_metadata(false)
.set_print_backend_config(false)
.set_print_operand_shape(false)
@@ -84,20 +92,28 @@ class HloPrintOptions {
.set_print_percent(false);
}
+ // Options to produce the canonical string representing an isomorphic
+ // computation graph.
+ static HloPrintOptions Canonical() {
+ return HloPrintOptions()
+ .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
+ .set_print_metadata(false)
+ .set_compact_operands(true)
+ .set_print_operand_shape(true)
+ .set_print_program_shape(false)
+ .set_print_percent(false)
+ .set_canonicalize_instruction_names(true);
+ }
+
// If true, large constants will be printed out.
HloPrintOptions& set_print_large_constants(bool value) {
print_large_constants_ = value;
return *this;
}
- // If true, the names of subcomputations (e.g. a fusion node's fused
- // computation) won't be printed. This makes the resulting text not parsable.
- //
- // A CustomCall's call target is printed even if
- // print_subcomputation_references is false, because the call target isn't an
- // HloComputation.
- HloPrintOptions& set_print_subcomputation_references(bool value) {
- print_subcomputation_references_ = value;
+ HloPrintOptions& set_print_subcomputation_mode(
+ PrintSubcomputationMode value) {
+ print_subcomputation_mode_ = value;
return *this;
}
@@ -138,15 +154,29 @@ class HloPrintOptions {
return *this;
}
+ // If true, canonicalizes instruction names. Instead of using "%foo.1" as
+ // the name of an instruction, we use "%tmp_1", "%tmp_2", etc.
+ HloPrintOptions& set_canonicalize_instruction_names(bool value) {
+ canonicalize_instruction_names_ = value;
+ return *this;
+ }
+
// The indent of the hlo text block.
HloPrintOptions& set_indent_amount(int value) {
indent_amount_ = value;
return *this;
}
+ // If true, indicates the instruction being printed is inside a nested
+ // computation.
+ HloPrintOptions& set_is_in_nested_computation(bool value) {
+ is_in_nested_computation_ = value;
+ return *this;
+ }
+
bool print_large_constants() const { return print_large_constants_; }
- bool print_subcomputation_references() const {
- return print_subcomputation_references_;
+ PrintSubcomputationMode print_subcomputation_mode() const {
+ return print_subcomputation_mode_;
}
bool print_metadata() const { return print_metadata_; }
bool print_backend_config() const { return print_metadata_; }
@@ -154,18 +184,51 @@ class HloPrintOptions {
bool print_operand_shape() const { return print_operand_shape_; }
bool print_program_shape() const { return print_program_shape_; }
bool print_percent() const { return print_percent_; }
+ bool canonicalize_instruction_names() const {
+ return canonicalize_instruction_names_;
+ }
int indent_amount() const { return indent_amount_; }
+ bool is_in_nested_computation() const { return is_in_nested_computation_; }
private:
bool print_large_constants_;
- bool print_subcomputation_references_;
+ PrintSubcomputationMode print_subcomputation_mode_;
bool print_metadata_;
bool print_backend_config_;
bool compact_operands_;
bool print_operand_shape_;
bool print_program_shape_;
bool print_percent_;
+ bool canonicalize_instruction_names_;
int indent_amount_;
+ bool is_in_nested_computation_;
+};
+
+// For canonical string output, we need to have a canonical way to rename
+// each instruction and its operands. Each operand is renamed as "tmp_<xxx>",
+// where <xxx> is an index starting from 0.
+class CanonicalNameMap {
+ public:
+ CanonicalNameMap() : index(0) {}
+
+ string LookupOrInsert(const string& old_name) {
+ auto iter = canonical_name_map.find(old_name);
+ if (iter != canonical_name_map.end()) {
+ return iter->second;
+ }
+
+ string new_name = tensorflow::strings::StrCat("tmp_", index++);
+ canonical_name_map[old_name] = new_name;
+ return new_name;
+ }
+ void Clear() {
+ canonical_name_map.clear();
+ index = 0;
+ }
+
+ private:
+ int64 index;
+ tensorflow::gtl::FlatMap<string, string> canonical_name_map;
};
// HLO instructions are the IR used by the high-level compiler.
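The CanonicalNameMap above is what makes two isomorphic computations print identically: first-seen names receive fresh tmp_N aliases, and repeated lookups are stable. A short illustration of that contract (the instruction names here are made up):

CanonicalNameMap names;
CHECK_EQ(names.LookupOrInsert("dot.42"), "tmp_0");
CHECK_EQ(names.LookupOrInsert("transpose.7"), "tmp_1");
CHECK_EQ(names.LookupOrInsert("dot.42"), "tmp_0");  // same alias on re-lookup
names.Clear();
CHECK_EQ(names.LookupOrInsert("transpose.7"), "tmp_0");  // numbering restarts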
@@ -185,8 +248,6 @@ class HloInstruction {
// Creates an instruction from the given proto. Arguments:
//
- // module: the module which will contain the instruction. The newly created
- // instruction is *not* added to the module or any computation, however.
// proto: the proto to convert from.
// instruction_map: a map from instruction id to HloInstruction*. This map
// must contain all operands of the newly constructed instruction.
@@ -194,7 +255,7 @@ class HloInstruction {
// must contain all computations which the newly constructed instruction
// calls.
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
- HloModule* module, const HloInstructionProto& proto,
+ const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
@@ -706,6 +767,9 @@ class HloInstruction {
// Note: only constant and parameter opcodes have an associated literal.
const Literal& literal() const;
+ // Returns whether there is a literal associated with this instruction.
+ bool HasLiteral() const;
+
// Returns the parameter number associated with this instruction.
//
// Note: only parameter opcodes have an associated parameter number.
@@ -1330,6 +1394,24 @@ class HloInstruction {
const ShapeIndex& shape_index = {});
private:
+ // Prints an instruction to a string.
+ //
+ // The canonical string representation needs to name operands and instruction
+ // names in a consistent way. This is implemented through the
+ // canonical_name_map.
+ string ToStringWithCanonicalNameMap(
+ const HloPrintOptions& options,
+ CanonicalNameMap* canonical_name_map) const;
+
+ // Prints an operand to a string.
+ string OperandsToStringWithCanonicalNameMap(
+ const HloPrintOptions& options,
+ CanonicalNameMap* canonical_name_map) const;
+
+ // Allow HloInstruction to access the ToStringWithCanonicalNameMap() and
+ // OperandsToStringWithCanonicalNameMap() functions.
+ friend class HloComputation;
+
enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
// Helper class for computing OperandElementUse for kFusion.
@@ -1576,13 +1658,20 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
// an HloInstruction* or a const HloInstruction*.
// To make the iteration order over the map deterministic, the comparator
// should not be using the pointer values, but rather an intrinsic property of
-// the hlo.
+// the hlo. Exception: null pointer values compare less than non-null.
//
// Note that this cannot be used for HLO instructions across multiple modules
// since the id of HLO instructions are only unique within each HLO module.
struct HloPtrComparator {
bool operator()(const HloInstruction* const& lhs,
const HloInstruction* const& rhs) const {
+ if (rhs == nullptr) {
+ // Nothing compares less than nullptr.
+ return false;
+ }
+ if (lhs == nullptr) {
+ return true;
+ }
return lhs->unique_id() < rhs->unique_id();
}
};
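The comparator change gives null keys a well-defined place in the order: nullptr sorts before every real instruction and never before itself, which keeps a std::map or std::set using this comparator strict-weak-ordered even if a null pointer slips in. A standalone sketch of the ordering property, using a simplified stand-in for HloInstruction since only unique_id matters here:

#include <cassert>

// Simplified stand-in for HloInstruction.
struct FakeHlo { int unique_id; };

// Mirrors the patched HloPtrComparator: nothing is less than nullptr,
// and nullptr is less than any non-null pointer.
struct PtrLess {
  bool operator()(const FakeHlo* lhs, const FakeHlo* rhs) const {
    if (rhs == nullptr) return false;
    if (lhs == nullptr) return true;
    return lhs->unique_id < rhs->unique_id;
  }
};

int main() {
  FakeHlo a{1}, b{2};
  PtrLess less;
  assert(less(nullptr, &a));       // nullptr sorts first
  assert(!less(&a, nullptr));      // nothing sorts before nullptr
  assert(!less(nullptr, nullptr)); // irreflexive, as ordered containers require
  assert(less(&a, &b));
  return 0;
}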
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 909cdc0b62..a61c472c72 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1336,5 +1336,163 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
"index_vector_dim=2, window_bounds={30,29,28,27,26}");
}
+TEST_F(HloInstructionTest, CanonicalStringificationFusion) {
+ // Tests canonical stringification of a simple op and of a fusion instruction.
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+ const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+ const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+ const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+ HloComputation::Builder builder("TransposeDot");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+
+ auto options = HloPrintOptions().Canonical();
+
+ EXPECT_EQ(dot->ToString(options),
+ "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), "
+ "lhs_contracting_dims={1}, rhs_contracting_dims={0}");
+
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* fusion = computation->CreateFusionInstruction(
+ {dot, reshape}, HloInstruction::FusionKind::kLoop);
+
+ EXPECT_EQ(
+ fusion->ToString(options),
+ R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls=
+{
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})");
+}
+
+TEST_F(HloInstructionTest, CanonicalStringificationWhile) {
+ // Tests canonical stringification of a while instruction wrapping the fused computation.
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+ const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+ const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+ const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+ HloComputation::Builder builder("TransposeDot");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
+ computation->CreateFusionInstruction({dot, reshape},
+ HloInstruction::FusionKind::kLoop);
+
+ HloInstruction* loop = builder.AddInstruction(
+ HloInstruction::CreateWhile(sout, computation, computation, x));
+
+ auto options = HloPrintOptions().Canonical();
+ EXPECT_EQ(loop->ToString(options),
+ R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
+{
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+ {
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+}, body=
+{
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+ {
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+})");
+}
+
+TEST_F(HloInstructionTest, CanonicalStringificationConditional) {
+ // Tests canonical stringification of a conditional instruction.
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+ const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+ const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+ const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+ HloComputation::Builder builder("TransposeDot");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
+ computation->CreateFusionInstruction({dot, reshape},
+ HloInstruction::FusionKind::kLoop);
+
+ builder.AddInstruction(
+ HloInstruction::CreateWhile(sout, computation, computation, x));
+
+ auto pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction* conditional =
+ builder.AddInstruction(HloInstruction::CreateConditional(
+ sout, pred, x, computation, x, computation));
+ auto options = HloPrintOptions().Canonical();
+ EXPECT_EQ(
+ conditional->ToString(options),
+ R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation=
+{
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+ {
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+}, false_computation=
+{
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+ {
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+})");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 5308fb5848..fbf1d58007 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -266,24 +266,44 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
<< ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
<< ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
- auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
- module_config);
-
tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
+ tensorflow::gtl::FlatMap<HloComputation*, int64> to_proto_id;
+ std::vector<std::unique_ptr<HloComputation>> computations;
+ HloComputation* entry = nullptr;
for (const HloComputationProto& computation_proto : proto.computations()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
- HloComputation::CreateFromProto(
- module.get(), computation_proto, computation_map));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloComputation> computation,
+ HloComputation::CreateFromProto(computation_proto, computation_map));
CHECK_NE(computation.get(), nullptr);
int64 computation_id = computation_proto.id();
TF_RET_CHECK(computation_id != -1);
TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
+ computation_map[computation_id] = computation.get();
+ to_proto_id[computation.get()] = computation_id;
+ if (computation_id == proto.entry_computation_id()) {
+ entry = computation.get();
+ }
+ computations.push_back(std::move(computation));
+ }
+ TF_RET_CHECK(entry != nullptr);
+
+ auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
+ module_config);
+
+ // Sort the computations in the proto id's order.
+ std::sort(computations.begin(), computations.end(),
+ [&](const std::unique_ptr<HloComputation>& a,
+ const std::unique_ptr<HloComputation>& b) {
+ return to_proto_id[a.get()] < to_proto_id[b.get()];
+ });
+
+ // Add sorted computations to the module.
+ for (auto& computation : computations) {
+ bool is_entry = computation.get() == entry;
// Don't uniquify names because we want names to be stable across
// serialization and deserialization.
- computation_map[computation_id] = module->AddComputationInternal(
- std::move(computation),
- /*is_entry=*/proto.entry_computation_id() == computation_id,
- /*uniquify_names=*/false);
+ module->AddComputationInternal(std::move(computation), is_entry,
+ /*uniquify_names=*/false);
}
TF_RET_CHECK(module->entry_computation_ != nullptr);
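
The reordering above boils down to remembering each computation's proto id in a side map and sorting the owned pointers by it before adding them to the module. A minimal standalone sketch of that pattern, with an invented Computation type and made-up ids:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Computation { std::string name; };

int main() {
  // Pretend these arrived from a proto in arbitrary order, with explicit ids.
  const std::vector<std::pair<int64_t, std::string>> protos = {
      {7, "c"}, {3, "a"}, {5, "b"}};

  std::map<const Computation*, int64_t> to_proto_id;
  std::vector<std::unique_ptr<Computation>> computations;
  for (const auto& p : protos) {
    auto c = std::make_unique<Computation>();
    c->name = p.second;
    to_proto_id[c.get()] = p.first;
    computations.push_back(std::move(c));
  }

  // Sort the computations in the proto ids' order, as the hunk above does.
  std::sort(computations.begin(), computations.end(),
            [&](const std::unique_ptr<Computation>& a,
                const std::unique_ptr<Computation>& b) {
              return to_proto_id[a.get()] < to_proto_id[b.get()];
            });

  for (const auto& c : computations) std::cout << c->name << "\n";  // a b c
}
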
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 54c34ce116..a41cfa7591 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
+#include <sstream>
#include <string>
#include <utility>
@@ -47,6 +48,9 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
case ComputationKind::kConditionalFalse:
repr += ":CONDITIONAL_FALSE";
break;
+ case ComputationKind::kCallFunction:
+ repr += ":CALL";
+ break;
}
return repr;
}
@@ -107,6 +111,31 @@ Status HloModuleGroupMetadata::Build() {
TF_RETURN_IF_ERROR(computation->Accept(visitor));
}
}
+ TF_RETURN_IF_ERROR(VerifyCompanionSets());
+ return Status::OK();
+}
+
+Status HloModuleGroupMetadata::VerifyCompanionSets() const {
+ // TODO(dlibenzi): Migrate this to use the device instead of module ID, once
+ // the kDomain CL goes in.
+ for (const auto& companions : companion_sets_) {
+ // A companion set must contain at most one instruction per
+ // device/module.
+ std::unordered_set<int64> devices;
+ for (HloInstruction* instruction : *companions) {
+ int64 device = GetModuleId(instruction->parent()->parent());
+ if (!devices.insert(device).second) {
+ std::stringstream ss;
+ ss << "Companion set:" << std::endl;
+ for (HloInstruction* hlo : *companions) {
+ ss << " " << hlo->name() << " ("
+ << GetModuleId(hlo->parent()->parent()) << ")" << std::endl;
+ }
+ ss << "has multiple instructions on the same device";
+ return FailedPrecondition("%s", ss.str().c_str());
+ }
+ }
+ }
return Status::OK();
}
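
The duplicate check above is the usual insert-into-a-set idiom: insert() reports an already-present device id through the .second of its result. A small hedged sketch of the same pattern with invented device ids:

#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

// Returns true if any device id appears more than once, i.e. the companion
// set would hold two instructions placed on the same device/module.
bool HasDuplicateDevice(const std::vector<int64_t>& device_ids) {
  std::unordered_set<int64_t> seen;
  for (int64_t device : device_ids) {
    if (!seen.insert(device).second) return true;  // insert failed: duplicate
  }
  return false;
}

int main() {
  std::cout << HasDuplicateDevice({0, 1, 2}) << "\n";  // 0 (ok)
  std::cout << HasDuplicateDevice({0, 1, 1}) << "\n";  // 1 (violation)
}
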
@@ -206,6 +235,9 @@ Status HloModuleGroupMetadata::RecordInstructions() {
TrackedInstruction(hlo, ComputationKind::kConditionalTrue);
tracked_instructions_[hlo->false_computation()] =
TrackedInstruction(hlo, ComputationKind::kConditionalFalse);
+ } else if (hlo->opcode() == HloOpcode::kCall) {
+ tracked_instructions_[hlo->to_apply()] =
+ TrackedInstruction(hlo, ComputationKind::kCallFunction);
}
if (!IsChannelInstruction(hlo)) {
return Status::OK();
@@ -258,7 +290,8 @@ Status HloModuleGroupMetadata::RecordInstructions() {
Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
HloInstruction* instruction2) {
TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile ||
- instruction1->opcode() == HloOpcode::kConditional);
+ instruction1->opcode() == HloOpcode::kConditional ||
+ instruction1->opcode() == HloOpcode::kCall);
VLOG(2) << "adding as companions:" << instruction1->ToString() << " and "
<< instruction2->ToString();
@@ -336,21 +369,11 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
}
}
- // Check if channel instructions are used only in allowed computations.
- const auto allowed = [this](HloInstruction* hlo) {
- HloComputation* computation = hlo->parent();
- const HloModule* module = computation->parent();
- if (module->entry_computation() == computation ||
- tracked_instructions_.count(computation) > 0) {
- return true;
- }
- return false;
- };
for (const Channel& channel : channels_) {
- if (!allowed(channel.send) || !allowed(channel.send_done) ||
- !allowed(channel.recv) || !allowed(channel.recv_done)) {
- return FailedPrecondition("channel is used in disallowed computation");
- }
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send));
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done));
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv));
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done));
}
// Check if the nest levels match for each channel.
for (const Channel& channel : channels_) {
@@ -368,4 +391,15 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
return Status::OK();
}
+Status HloModuleGroupMetadata::CheckCommunicatingInstruction(
+ HloInstruction* instruction) const {
+ HloComputation* computation = instruction->parent();
+ const HloModule* module = computation->parent();
+ if (module->entry_computation() == computation ||
+ tracked_instructions_.count(computation) > 0) {
+ return Status::OK();
+ }
+ return FailedPrecondition("channel is used in disallowed computation");
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index c48a7ab0b5..3ef4542f91 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -60,6 +60,7 @@ class HloModuleGroupMetadata {
kWhileBody,
kConditionalTrue,
kConditionalFalse,
+ kCallFunction,
};
// Tracks the instruction mapped to a given computation, and the computation
@@ -202,6 +203,15 @@ class HloModuleGroupMetadata {
Status AddCompanion(HloInstruction* instruction1,
HloInstruction* instruction2);
+ // Checks whether a communicating instruction is placed in a valid position
+ // within the graph.
+ Status CheckCommunicatingInstruction(HloInstruction* instruction) const;
+
+ // Performs a consistency check on the companion sets built for the input
+ // modules. Check that a companion set does not include instructions from the
+ // same module/device.
+ Status VerifyCompanionSets() const;
+
// Retrieves a pointer to the stored TrackedInstruction associated with a
// tracked computation, or nullptr in case such computation is not tracked.
const TrackedInstruction* GetTrackedInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index ca763076a1..ac7cd2f2f5 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -74,6 +74,7 @@ namespace xla {
V(kDynamicUpdateSlice, "dynamic-update-slice") \
V(kEq, "equal-to", kHloOpcodeIsComparison) \
V(kExp, "exponential") \
+ V(kExpm1, "exponential-minus-one") \
V(kFft, "fft") \
V(kFloor, "floor") \
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
@@ -87,6 +88,7 @@ namespace xla {
V(kIsFinite, "is-finite") \
V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \
V(kLog, "log") \
+ V(kLog1p, "log-plus-one") \
V(kAnd, "and") \
V(kNot, "not") \
V(kOr, "or") \
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index 23ace5afea..36ee7bcf84 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -1,3 +1,5 @@
+
+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -62,7 +64,35 @@ StatusOr<int64> MinimumMemoryForSequence(
namespace {
// Class implementing a list scheduler of HLO instructions which produces a
-// sequence which minimizes memory usage.
+// sequence which minimizes memory usage by preferring to schedule nodes that
+// free larger buffers and define smaller outputs.
+//
+// Note that the list scheduler is a greedy algorithm which cannot guarantee a
+// globally optimal solution. As a counterexample, consider the following
+// graph:
+//
+// +--> B ===> C -------+
+// A -> | |
+// | v
+// +--> D ---> F=======>G
+// | ^
+// | |
+// +--> E -----+
+//
+// --> : Buffer with size 1
+// ==> : Buffer with size 2
+//
+// The list scheduler will always try to defer scheduling B in a greedy way
+// since its output buffer is bigger than its input. The sequence it creates
+// will be:
+//   A D E F B C G
+// which has a maximum memory usage of 5 (at one point, B and F will be alive
+// together).
+//
+// An optimal schedule for the previous graph would be:
+//   A B C D E F G
+// which has a maximum memory usage of 4.
+//
class ListScheduler {
public:
// Construct and return a memory-minimizing sequence of HLO instructions
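
The counterexample above is easier to check by hand with a tiny peak-memory calculator. The standalone sketch below assumes a simplified liveness model (a buffer is freed at its last use, just before the consumer's output is allocated), so its absolute peaks (4 vs. 3) differ from the 5-vs-4 figures in the comment, but it shows the same effect: the greedy order A D E F B C G peaks higher than A B C D E F G. Node names and sizes come from the ASCII graph; everything else is invented.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct NodeInfo {
  int64_t output_size;                // size of the buffer this node defines
  std::vector<std::string> operands;  // nodes whose output buffers it reads
};

// Peak of the sum of live buffer sizes over a schedule. A buffer is freed at
// its last use, just before the consumer's output is allocated.
int64_t PeakMemory(const std::vector<std::string>& schedule,
                   const std::map<std::string, NodeInfo>& graph) {
  std::map<std::string, int> last_use;
  for (int i = 0; i < static_cast<int>(schedule.size()); ++i) {
    for (const std::string& op : graph.at(schedule[i]).operands) {
      last_use[op] = i;
    }
  }
  int64_t live = 0;
  int64_t peak = 0;
  for (int i = 0; i < static_cast<int>(schedule.size()); ++i) {
    const NodeInfo& node = graph.at(schedule[i]);
    for (const std::string& op : node.operands) {
      if (last_use[op] == i) live -= graph.at(op).output_size;
    }
    live += node.output_size;
    peak = std::max(peak, live);
  }
  return peak;
}

int main() {
  // The graph from the comment above: "-->" buffers have size 1, "==>" size 2.
  const std::map<std::string, NodeInfo> g = {
      {"A", {1, {}}},    {"B", {2, {"A"}}}, {"C", {1, {"B"}}},
      {"D", {1, {"A"}}}, {"E", {1, {"A"}}}, {"F", {2, {"D", "E"}}},
      {"G", {1, {"C", "F"}}}};
  std::cout << PeakMemory({"A", "D", "E", "F", "B", "C", "G"}, g) << "\n";  // 4
  std::cout << PeakMemory({"A", "B", "C", "D", "E", "F", "G"}, g) << "\n";  // 3
}
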
@@ -366,10 +396,10 @@ StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
} // namespace
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+StatusOr<std::vector<const HloInstruction*>> DFSMemorySchedulerImpl(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function) {
+ const LogicalBuffer::SizeFunction& size_function, bool reverse_heuristics) {
// This ordering is based on DFS post-order, with a heuristic to decide which
// operand to visit first. The heuristic is based on 'extra_users', which is
// simply users-1 for each instruction. By subtracting 1, we're saying that
@@ -409,19 +439,20 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
return Status::OK();
});
TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder(
- &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
- const HloInstruction* b) {
- if (extra_users[a] != extra_users[b]) {
- return extra_users[a] > extra_users[b];
- }
- if (total_sizes[a] != total_sizes[b]) {
- return total_sizes[a] > total_sizes[b];
- }
- return a->name() < b->name();
+ &visitor, [&extra_users, &total_sizes, reverse_heuristics](
+ const HloInstruction* a, const HloInstruction* b) {
+ auto lhs = std::tuple<int64, int64, string>(extra_users[a],
+ total_sizes[a], b->name());
+ auto rhs = std::tuple<int64, int64, string>(extra_users[b],
+ total_sizes[b], a->name());
+
+ // Reverse heuristics. This helps some cases as a different starting
+ // point of gradient descent, see b/78906799 for more context.
+ return reverse_heuristics ? rhs > lhs : lhs > rhs;
}));
CHECK_EQ(sequence.size(), computation.instruction_count());
return sequence;
-}
+}
StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
const HloComputation& computation,
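
The new comparator in the hunk above packs the three criteria into std::tuple so the lexicographic operator> does the work, and flips the result when reverse_heuristics is set. A standalone sketch of that idiom follows; Stats and its fields are invented, and the crossed names reproduce the original a->name() < b->name() tie-break.

#include <cstdint>
#include <iostream>
#include <string>
#include <tuple>

struct Stats {
  int64_t extra_users;
  int64_t total_size;
  std::string name;
};

// Visit `a` before `b` when it has more extra users, then a larger total size,
// with the name as a deterministic tie-break; `reverse` flips the heuristic.
bool VisitFirst(const Stats& a, const Stats& b, bool reverse) {
  auto lhs = std::make_tuple(a.extra_users, a.total_size, b.name);
  auto rhs = std::make_tuple(b.extra_users, b.total_size, a.name);
  return reverse ? rhs > lhs : lhs > rhs;
}

int main() {
  Stats x{3, 10, "x"}, y{3, 20, "y"};
  std::cout << VisitFirst(x, y, /*reverse=*/false) << "\n";  // 0: y is larger
  std::cout << VisitFirst(x, y, /*reverse=*/true) << "\n";   // 1
}
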
@@ -439,6 +470,22 @@ StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
post_order.end()};
}
+StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function) {
+ return DFSMemorySchedulerImpl(computation, points_to_analysis, size_function,
+ /*reverse_heuristics=*/false);
+}
+
+StatusOr<std::vector<const HloInstruction*>> DFSMemorySchedulerReverse(
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function) {
+ return DFSMemorySchedulerImpl(computation, points_to_analysis, size_function,
+ /*reverse_heuristics=*/true);
+}
+
StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
@@ -478,19 +525,34 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
VLOG(2) << "Min-memory post order sequence: "
<< HumanReadableNumBytes(post_order_memory);
- if (post_order_memory < std::min(list_memory, dfs_memory)) {
- VLOG(2) << "Chose min-memory post_order sequence: "
- << HumanReadableNumBytes(post_order_memory);
- return post_order_sequence;
+ TF_ASSIGN_OR_RETURN(std::vector<const HloInstruction*> reverse_dfs,
+ DFSMemorySchedulerReverse(computation, points_to_analysis,
+ size_function));
+ TF_ASSIGN_OR_RETURN(
+ const int64 reverse_dfs_memory,
+ MinimumMemoryForComputation(computation, reverse_dfs, points_to_analysis,
+ size_function));
+ VLOG(2) << "Min-memory reverse_dfs sequence: "
+ << HumanReadableNumBytes(reverse_dfs_memory);
+ auto min_memory = std::min(
+ {dfs_memory, post_order_memory, reverse_dfs_memory, list_memory});
- } else if (list_memory <= dfs_memory) {
+ if (min_memory == list_memory) {
VLOG(2) << "Chose min-memory list sequence: "
<< HumanReadableNumBytes(list_memory);
return list_sequence;
- } else {
+ } else if (min_memory == dfs_memory) {
VLOG(2) << "Chose min-memory dfs sequence: "
<< HumanReadableNumBytes(dfs_memory);
return dfs_sequence;
+ } else if (min_memory == reverse_dfs_memory) {
+ VLOG(2) << "Chose min-memory reverse_dfs memory: "
+ << HumanReadableNumBytes(reverse_dfs_memory);
+ return reverse_dfs;
+ } else {
+ VLOG(2) << "Chose min-memory post_order sequence: "
+ << HumanReadableNumBytes(post_order_memory);
+ return post_order_sequence;
}
}
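
The if/else chain above simply returns whichever candidate schedule achieved the smallest minimum memory, preferring earlier candidates on ties. A compact hedged sketch of the same selection with invented numbers:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Candidate scheduler -> the minimum memory it achieved (invented numbers).
  const std::vector<std::pair<std::string, int64_t>> candidates = {
      {"list", 96}, {"dfs", 104}, {"reverse_dfs", 88}, {"post_order", 120}};

  // min_element keeps the first minimum, so listing candidates in priority
  // order matches the tie-breaking of the if/else chain above.
  const auto best = std::min_element(
      candidates.begin(), candidates.end(),
      [](const auto& a, const auto& b) { return a.second < b.second; });

  std::cout << "Chose min-memory " << best->first
            << " sequence: " << best->second << " bytes\n";
}
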
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h
index fcb006f818..ef612414aa 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.h
@@ -61,6 +61,13 @@ StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function);
+// DFS-order scheduler with reversed heuristics. This helps some cases (see
+// b/78906799).
+StatusOr<std::vector<const HloInstruction*>> DFSMemorySchedulerReverse(
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function);
+
// The default scheduling algorithm. Runs both the list scheduler
// and the DFS scheduler, and chooses whichever returns a lower min-memory,
// not accounting for fragmentation.
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index 92df7c1427..4e956af565 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -190,5 +190,108 @@ ENTRY root {
instructions_by_name.at("e")));
}
+// The current scheduler is suboptimal, in that it does not account for the
+// memory used by subcomputations when choosing a schedule.
+// This test demonstrates the current behavior.
+// We are working on improving it (b/65409243).
+TEST_F(HloSchedulingTest, SubcomputationsNotAccounted) {
+ // %WhileCond (cond_param: f32[4]) -> pred[] {
+ // %cond_param = f32[4]{0} parameter(0)
+ // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } })
+ // ROOT %not-equal-to = pred[] not-equal-to(
+ // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant)
+ // }
+ // %WhileBody (body_param: f32[4]) -> f32[4] {
+ // %body_param = f32[4]{0} parameter(0)
+ // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } })
+ // ROOT %subtract = f32[4]{0} subtract(
+ // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1)
+ // }
+ // %SubcomputationsNotAccounted () -> f32[2,4] {
+ // %constant.3 = f32[2,4]{1,0} constant(
+ // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
+ // %transpose = f32[2,4]{1,0} transpose(
+ // f32[2,4]{1,0} %constant.3), dimensions={0,1}
+ // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } })
+ // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2),
+ // condition=%WhileCond,
+ // body=%WhileBody
+ // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0}
+ // ROOT %add = f32[2,4]{1,0} add(
+ // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast)
+ // }
+
+ auto module = CreateNewModule();
+ const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
+ const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
+
+ // param != 0
+ // Needs 17 bytes
+ auto cond_builder = HloComputation::Builder("WhileCond");
+ HloInstruction* cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "cond_param"));
+ HloInstruction* zero_vector = cond_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2<float>({{0, 0, 0, 0}})));
+ cond_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
+ auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
+
+ // param - 1
+ // Needs 16 bytes
+ auto body_builder = HloComputation::Builder("WhileBody");
+ HloInstruction* body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "body_param"));
+ HloInstruction* one_vector = body_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ body_builder.AddInstruction(HloInstruction::CreateBinary(
+ r1f32, HloOpcode::kSubtract, body_param, one_vector));
+ auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
+
+ // transpose(matrix) + bcast(while)
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* while_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ // Creates 16 bytes, ignoring subcomputations
+ HloInstruction* while_loop =
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ r1f32, cond_computation, body_computation, while_init));
+
+ // Creates 32 bytes and frees 16
+ HloInstruction* bcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r2f32, while_loop, {0}));
+
+ HloInstruction* matrix = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2<float>(
+ {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
+ // Creates 32 bytes
+ HloInstruction* transpose = builder.AddInstruction(
+ HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));
+
+ // Creates 32 bytes and frees 64
+ HloInstruction* add = builder.AddInstruction(
+ HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));
+
+ module->AddEntryComputation(builder.Build());
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ CreateMemoryMinimizingSequence(*module, [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ }));
+ // Verify that all instructions are in the sequence.
+ EXPECT_EQ(module->entry_computation()->instruction_count(),
+ sequence.at(module->entry_computation()).size());
+ SequentialHloOrdering ordering(module.get(), sequence);
+ // TODO(b/65409243): while_loop is scheduled first by List; it's thought to be
+ // cheaper than transpose because the temporary memory needed for
+ // subcomputations is ignored. If we count the temporary memory as part of
+ // bytes_defined, then transpose would be scheduled first. Incidentally,
+ // ignoring subcomputations results in a better schedule here.
+ EXPECT_TRUE(ordering.ExecutesBefore(while_loop, transpose));
+ EXPECT_TRUE(ordering.ExecutesBefore(bcast, transpose));
+ EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
+ EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
+}
+
} // namespace
} // namespace xla
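
The byte counts quoted in the test comments above follow from a dense f32 layout: 4 bytes per element, so f32[4] is 16 bytes and f32[2,4] is 32. A tiny sketch of that arithmetic; the helper below is invented, not ShapeUtil::ByteSizeOf.

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Dense-layout size: bytes per element times the product of the dimensions.
int64_t ByteSize(const std::vector<int64_t>& dims, int64_t bytes_per_element) {
  return std::accumulate(dims.begin(), dims.end(), bytes_per_element,
                         std::multiplies<int64_t>());
}

int main() {
  std::cout << ByteSize({4}, 4) << "\n";     // 16
  std::cout << ByteSize({2, 4}, 4) << "\n";  // 32
}
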
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 096ebb7946..7d6d0d9eaf 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -106,9 +106,7 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
reduce_precision->mantissa_bits()));
}
-Status ShapeVerifier::HandleInfeed(HloInstruction*) {
- return tensorflow::Status::OK();
-}
+Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); }
Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) {
// Outfeed has a separate shape field for the value which is outfed to the
@@ -127,12 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) {
}
Status ShapeVerifier::HandleHostCompute(HloInstruction*) {
- return tensorflow::Status::OK();
+ return Status::OK();
}
-Status ShapeVerifier::HandleRng(HloInstruction*) {
- return tensorflow::Status::OK();
-}
+Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); }
Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
return CheckShape(
@@ -164,7 +160,7 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
}
Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
- return tensorflow::Status::OK();
+ return Status::OK();
}
Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
@@ -183,7 +179,7 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
operand_shape.dimensions(operand_dimension))
<< broadcast->ToString() << " operand shape " << operand_shape;
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
@@ -191,7 +187,7 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape()));
TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
ShapeUtil::ElementsIn(reshape->operand(0)->shape()));
- return tensorflow::Status::OK();
+ return Status::OK();
}
Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
@@ -201,21 +197,17 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
}
Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
- return tensorflow::Status::OK();
+ return Status::OK();
}
-Status ShapeVerifier::HandleFusion(HloInstruction*) {
- return tensorflow::Status::OK();
-}
+Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); }
Status ShapeVerifier::HandleCall(HloInstruction* call) {
// The shape of kCall should match the shape of the computation it calls.
return CheckShape(call, call->to_apply()->ComputeProgramShape().result());
}
-Status ShapeVerifier::HandleCustomCall(HloInstruction*) {
- return tensorflow::Status::OK();
-}
+Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); }
Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
return CheckShape(slice,
@@ -497,7 +489,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
ShapeUtil::HumanString(instruction->shape()).c_str(),
instruction->ToString().c_str());
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
@@ -547,7 +539,7 @@ Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1,
instr1->ToString().c_str(), instr1->channel_id(),
instr2->ToString().c_str(), instr2->channel_id());
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
string ComputationsToString(
@@ -612,7 +604,7 @@ Status VerifyHloStructure(HloModule* module) {
}
}
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
@@ -728,7 +720,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
// TODO(b/65423525): We'd like to check that all operands are distinct.
// This is currently disabled due to the invariant being violated by
// multi-output fusion.
- return tensorflow::Status::OK();
+ return Status::OK();
}
Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
@@ -777,7 +769,7 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
"init: %s, body: %s",
init->ToString().c_str(), body_root->ToString().c_str());
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
@@ -795,7 +787,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
ShapeUtil::HumanString(operand_shape).c_str());
}
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
StatusOr<bool> HloVerifier::Run(HloModule* module) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 6208887547..1392a78097 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -82,9 +82,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
Status HandleGather(HloInstruction* gather) override;
- Status FinishVisit(HloInstruction*) override {
- return tensorflow::Status::OK();
- }
+ Status FinishVisit(HloInstruction*) override { return Status::OK(); }
protected:
// Check the instruction's shape against the shape given by ShapeInference
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
index 13e4557317..dc3bfce0c4 100644
--- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
+++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
@@ -27,6 +27,7 @@ using tensorflow::strings::HumanReadableElapsedTime;
using tensorflow::strings::HumanReadableNumBytes;
using tensorflow::strings::Printf;
using tensorflow::strings::StrAppend;
+using tensorflow::strings::StrCat;
string HumanReadableProfileBuilder::ToString() const {
string s;
@@ -35,20 +36,26 @@ string HumanReadableProfileBuilder::ToString() const {
computation_name_.c_str(),
HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str());
- auto append_op = [&](const OpInfo& op) {
+ auto print_op = [&](const OpInfo& op) {
+ // Skip ops with 0 optimal seconds and 0 actual cycles. These are ops that
+ // were expected to be free and are actually free -- things like (on most
+ // backends) kParameter or kConstant HLOs. There's no need to clutter the
+ // profile with these.
+ if (op.optimal_seconds == 0 && op.cycles == 0) {
+ return;
+ }
+
string bytes_per_sec;
string bytes_per_cycle;
- if (op.cycles <= 0 || op.bytes_accessed < 0) {
- bytes_per_sec = "<unknown>";
- bytes_per_cycle = "<unknown>";
- } else {
- bytes_per_sec =
- HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles));
+ if (op.cycles > 0 && op.bytes_accessed >= 0) {
+ bytes_per_sec = StrCat(
+ HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)),
+ "/s");
+ double bpc = static_cast<double>(op.bytes_accessed) / op.cycles;
if (op.bytes_accessed > op.cycles) {
- bytes_per_cycle = HumanReadableNumBytes(op.bytes_accessed / op.cycles);
+ bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle");
} else {
- bytes_per_cycle =
- Printf("%.3fB", static_cast<float>(op.bytes_accessed) / op.cycles);
+ bytes_per_cycle = Printf("%.3fB/cycle", bpc);
}
}
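
The throughput strings in the hunk above are plain ratios: bytes per second divides bytes_accessed by the op's wall time (cycles converted through the clock rate, as CyclesToSeconds() does), and bytes per cycle divides by the raw cycle count. A hedged sketch with invented numbers, minus the human-readable formatting:

#include <cstdint>
#include <iostream>

int main() {
  const double clock_rate_ghz = 1.5;       // invented clock rate
  const int64_t cycles = 3000000;          // invented cycle count for one op
  const int64_t bytes_accessed = 6000000;  // invented byte count

  const double seconds = cycles / (clock_rate_ghz * 1e9);
  const double bytes_per_sec = bytes_accessed / seconds;
  const double bytes_per_cycle = static_cast<double>(bytes_accessed) / cycles;

  std::cout << bytes_per_sec << " B/s, " << bytes_per_cycle << " B/cycle\n";
  // 3e+09 B/s, 2 B/cycle
}
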
@@ -59,14 +66,16 @@ string HumanReadableProfileBuilder::ToString() const {
double nsecs = op.cycles / clock_rate_ghz_;
Appendf(&s,
- "%15lld cycles (%6.2f%%) :: %12.1f usec (%12.1f optimal) :: %18s "
- ":: %18s :: %12s/s :: %12s/cycle :: %s\n",
+ "%15lld cycles (%6.2f%%) :: %12.1f usec %22s :: %18s "
+ ":: %18s :: %14s :: %16s :: %s\n",
op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles),
- op.optimal_seconds * 1e6,
+ op.optimal_seconds < 0
+ ? ""
+ : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(),
op.flop_count <= 0
- ? "<none>"
+ ? ""
: HumanReadableNumFlops(op.flop_count, nsecs).c_str(),
- op.transcendental_count <= 0 ? "<none>"
+ op.transcendental_count <= 0 ? ""
: HumanReadableNumTranscendentalOps(
op.transcendental_count, nsecs)
.c_str(),
@@ -78,24 +87,26 @@ string HumanReadableProfileBuilder::ToString() const {
int64 total_transcendentals = 0.;
int64 total_bytes = 0;
for (const auto& op : op_infos_) {
- optimal_seconds_sum += op.optimal_seconds;
- total_flops += op.flop_count;
- total_transcendentals += op.transcendental_count;
- total_bytes += op.bytes_accessed;
+ if (op.optimal_seconds > 0) {
+ optimal_seconds_sum += op.optimal_seconds;
+ }
+ total_flops += std::max(op.flop_count, int64{0});
+ total_transcendentals += std::max(op.transcendental_count, int64{0});
+ total_bytes += std::max(op.bytes_accessed, int64{0});
}
VLOG(1) << "Total floating point ops: " << total_flops;
- append_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops,
- total_transcendentals, total_bytes, optimal_seconds_sum});
+ print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops,
+ total_transcendentals, total_bytes, optimal_seconds_sum});
- // Sort ops in decreasing order of cycles.
+ // Sort ops in decreasing order of cycles, and print them.
std::vector<OpInfo> sorted_ops(op_infos_);
std::sort(
sorted_ops.begin(), sorted_ops.end(),
[](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; });
for (const auto& op : sorted_ops) {
- append_op(op);
+ print_op(op);
}
if (total_cycles_ <= 0) {
@@ -109,8 +120,20 @@ string HumanReadableProfileBuilder::ToString() const {
table.SetMetricName("microseconds above estimated optimum");
table.SetEntryName("ops");
table.SetShowCategoryTable();
+ table.SetShowAllEntries();
float total_discrepancy_in_microseconds = 0.0f;
- for (const auto& op : sorted_ops) {
+ for (const auto& op : op_infos_) {
+ // Skip ops with < 0 optimal seconds. These are ops for which we don't
+ // know the optimal time.
+ if (op.optimal_seconds < 0) {
+ continue;
+ }
+ // Also skip ops with 0 actual cycles. These ops were free; there's no
+ // need to clutter the "above estimated optimum" table with them,
+ // because they can't be optimized further.
+ if (op.cycles == 0) {
+ continue;
+ }
MetricTableReport::Entry entry;
entry.text = op.name;
entry.short_text = op.short_name;
@@ -128,7 +151,14 @@ string HumanReadableProfileBuilder::ToString() const {
table.SetMetricName("microseconds");
table.SetEntryName("ops");
table.SetShowCategoryTable();
- for (const auto& op : sorted_ops) {
+ table.SetShowAllEntries();
+ for (const auto& op : op_infos_) {
+ // Skip ops with 0 optimal seconds and 0 actual cycles. As in
+ // print_op(), these are uninteresting because they're expected to be
+ // free, and they were actually free.
+ if (op.cycles == 0 && op.optimal_seconds == 0) {
+ continue;
+ }
MetricTableReport::Entry entry;
entry.text = op.name;
entry.short_text = op.short_name;
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h
index fb36d3a0d6..6f56c3aa82 100644
--- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h
+++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h
@@ -41,7 +41,8 @@ class HumanReadableProfileBuilder {
int64 total_cycles() const { return total_cycles_; }
// Adds an operation to the profile. If you don't know the number of
- // floating-point ops or bytes touched by the op, pass -1 for that param.
+ // floating-point ops or bytes touched by the op, or if you don't know how
+ // fast it would run optimally, pass -1 for that param.
void AddOp(tensorflow::StringPiece op_name,
tensorflow::StringPiece short_name,
tensorflow::StringPiece category, int64 cycles, int64 flop_count,
@@ -62,10 +63,10 @@ class HumanReadableProfileBuilder {
string short_name;
string category;
int64 cycles;
- int64 flop_count;
+ int64 flop_count; // -1 if unknown
int64 transcendental_count;
- int64 bytes_accessed;
- float optimal_seconds;
+ int64 bytes_accessed; // -1 if unknown
+ float optimal_seconds; // -1 if unknown
};
double CyclesToSeconds(int64 cycles) const {
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 7aa1c7c835..d2af261008 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = Literal::CreateR1<float>({4, 3, 3, 4});
- LiteralTestUtil::ExpectEqual(*result, *expected);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
// Test that `constant` function is changed to `broadcast`.
@@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = Literal::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
- LiteralTestUtil::ExpectEqual(*result, *expected);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
@@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = Literal::CreateR1<float>({3, 1, -1, -3});
- LiteralTestUtil::ExpectEqual(*result, *expected);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 6bb2ca19fe..06b84cc145 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -120,11 +120,13 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kDivide:
case HloOpcode::kDot:
case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
case HloOpcode::kFft:
case HloOpcode::kFusion:
case HloOpcode::kGather:
case HloOpcode::kHostCompute:
case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
case HloOpcode::kMap:
case HloOpcode::kParameter:
case HloOpcode::kPower:
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index 6dd8fa1ab0..cf9673a38a 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -92,7 +92,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
EXPECT_FALSE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
- .ValueOrDie());
+ .ValueOrDie())
+ << module->ToString();
}
// Counts the number of HLO ops with a given op code in the specified module.
@@ -151,7 +152,11 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
.Run(module.get())
.ValueOrDie())
<< module->ToString();
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_THAT(root->fused_expression_root(),
+ op::Subtract(op::Abs(op::Parameter()), op::Parameter()))
+ << module->ToString();
// Make sure the add hasn't been duplicated.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
@@ -244,7 +249,12 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
.Run(module.get())
.ValueOrDie())
<< module->ToString();
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
+ root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_THAT(root->fused_expression_root(),
+ op::Tuple(op::Subtract(op::Parameter(), op::Parameter()),
+ op::Subtract(op::Parameter(), op::Parameter())))
+ << module->ToString();
// Make sure we didn't duplicate any adds.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 7e1bb11eaa..986e177406 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -660,13 +660,12 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
/*device_allocator=*/nullptr)
.ConsumeValueOrDie();
- EXPECT_EQ(
- ::tensorflow::Status::OK(),
- backend()
- .compiler()
- ->RunBackend(std::move(module), backend().default_stream_executor(),
- /*device_allocator=*/nullptr)
- .status());
+ EXPECT_EQ(Status::OK(), backend()
+ .compiler()
+ ->RunBackend(std::move(module),
+ backend().default_stream_executor(),
+ /*device_allocator=*/nullptr)
+ .status());
}
// A GTE inside of a fusion node inherits the layout of its operand (which
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index bc683a1880..f172b1d87c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -151,7 +151,7 @@ Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) {
Status FusedIrEmitter::FinishVisit(HloInstruction* root) {
fused_root_ = root;
- return tensorflow::Status::OK();
+ return Status::OK();
}
FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index 3978acc132..0728ccfff7 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -39,14 +39,13 @@ LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape,
LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
const IrArray& target_array,
llvm::IRBuilder<>* ir_builder)
- : body_emitter_([=](const llvm_ir::IrArray::Index array_index)
- -> ::tensorflow::Status {
+ : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status {
// Convert target_element_generator to a BodyEmitter.
TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
target_element_generator(array_index));
target_array.EmitWriteArrayElement(array_index, target_element,
ir_builder);
- return tensorflow::Status::OK();
+ return Status::OK();
}),
shape_(target_array.GetShape()),
ir_builder_(ir_builder) {}
@@ -124,7 +123,7 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
return {array_index};
}
-tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) {
+Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) {
for (const IrArray::Index& array_index :
EmitIndexAndSetExitBasicBlock(loop_name)) {
TF_RETURN_IF_ERROR(body_emitter_(array_index));
@@ -135,7 +134,7 @@ tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) {
if (exit_bb_ != nullptr) {
ir_builder_->SetInsertPoint(exit_bb_);
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
} // namespace llvm_ir
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index 9ff497aecd..b70d28ecd3 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -38,8 +38,7 @@ using ElementGenerator =
// Emits a loop for every element in the given shape.
class LoopEmitter {
public:
- using BodyEmitter =
- std::function<tensorflow::Status(const IrArray::Index& index)>;
+ using BodyEmitter = std::function<Status(const IrArray::Index& index)>;
LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape,
llvm::IRBuilder<>* ir_builder);
@@ -72,7 +71,7 @@ class LoopEmitter {
tensorflow::StringPiece loop_name);
// Emits a complete loop nest for every element in the given shape.
- tensorflow::Status EmitLoop(tensorflow::StringPiece loop_name = "");
+ Status EmitLoop(tensorflow::StringPiece loop_name = "");
protected:
// An IR emitter that generates the loop body.
diff --git a/tensorflow/compiler/xla/service/owning_device_memory.cc b/tensorflow/compiler/xla/service/owning_device_memory.cc
new file mode 100644
index 0000000000..c115bc097f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/owning_device_memory.cc
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/owning_device_memory.h"
+
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+
+namespace xla {
+
+void OwningDeviceMemory::Free() {
+ CHECK(allocator_ != nullptr)
+ << "Can't call Free() on an inactive (i.e. moved from, Forget()'ten, "
+ "or Free()'ed) instance.";
+ auto status = allocator_->Deallocate(device_ordinal_, mem_);
+ if (!status.ok()) {
+ LOG(WARNING) << "Deallocating buffer " << mem_.opaque() << " failed.";
+ }
+
+ allocator_ = nullptr;
+ mem_ = se::DeviceMemoryBase();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/owning_device_memory.h b/tensorflow/compiler/xla/service/owning_device_memory.h
new file mode 100644
index 0000000000..9cf071f0d9
--- /dev/null
+++ b/tensorflow/compiler/xla/service/owning_device_memory.h
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// Break circular dependency between this file and device_memory_allocator.h.
+class DeviceMemoryAllocator;
+
+// Owning pointer for memory on a device.
+//
+// OwningDeviceMemory is an owning pointer like std::unique_ptr, but it can
+// point to memory that resides on a "device" (e.g. a GPU). When an
+// OwningDeviceMemory goes out of scope, it frees the memory it owns.
+//
+// We say that an instance of OwningDeviceMemory is "active" if it currently
+// owns a (possibly empty) slice of memory on the device. Moving, Forget()'ing,
+// Free()'ing, and other actions can deactivate an active object.
+//
+// Note that we can't simply use stream_executor::ScopedDeviceMemory instead of
+// OwningDeviceMemory, because ScopedDeviceMemory frees its pointer via a
+// StreamExecutor. This class needs to free via an xla::DeviceMemoryAllocator.
+class OwningDeviceMemory {
+ public:
+ OwningDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {}
+
+ explicit OwningDeviceMemory(se::DeviceMemoryBase mem, int device_ordinal,
+ DeviceMemoryAllocator* allocator)
+ : mem_(mem), device_ordinal_(device_ordinal), allocator_(allocator) {
+ CHECK(allocator != nullptr) << "allocator cannot be null.";
+ }
+
+ OwningDeviceMemory(OwningDeviceMemory&& other)
+ : mem_(other.mem_),
+ device_ordinal_(other.device_ordinal_),
+ allocator_(other.allocator_) {
+ other.mem_ = se::DeviceMemoryBase();
+ other.allocator_ = nullptr;
+ }
+
+ OwningDeviceMemory& operator=(OwningDeviceMemory&& other) {
+ if (allocator_ != nullptr) {
+ Free();
+ }
+ mem_ = other.mem_;
+ device_ordinal_ = other.device_ordinal_;
+ allocator_ = other.allocator_;
+
+ other.mem_ = se::DeviceMemoryBase();
+ other.allocator_ = nullptr;
+ return *this;
+ }
+
+ // Deactivates this instance if it's active. Nop if it's not active.
+ OwningDeviceMemory& operator=(std::nullptr_t) {
+ if (allocator_ != nullptr) {
+ Free();
+ }
+ return *this;
+ }
+
+ ~OwningDeviceMemory() {
+ if (allocator_ != nullptr) {
+ Free();
+ }
+ }
+
+ // The returned allocator is nonnull iff this object is active.
+ DeviceMemoryAllocator* allocator() const { return allocator_; }
+
+ int device_ordinal() const { return device_ordinal_; }
+
+ // Gets the device memory pointer.
+ const void* opaque() const { return mem_.opaque(); }
+ void* opaque() { return mem_.opaque(); }
+
+ uint64 size() const { return mem_.size(); }
+
+ // Determines whether this wraps a null pointer.
+ //
+ // !is_null() is sufficient but not necessary to imply `this` is active.
+ bool is_null() const { return mem_.is_null(); }
+
+ se::DeviceMemoryBase AsDeviceMemoryBase() {
+ return se::DeviceMemoryBase(opaque(), size(), /*is_sub_buffer=*/false);
+ }
+
+ // Returns the wrapped DeviceMemoryBase without freeing it, and deactivates
+ // this object. Precondition: `this` is active.
+ TF_MUST_USE_RESULT se::DeviceMemoryBase Forget() {
+ CHECK(allocator_ != nullptr)
+ << "Can't call Forget() on an inactive (i.e. moved from, Forget()'ten, "
+ "or Free()'ed) instance.";
+ allocator_ = nullptr;
+ se::DeviceMemoryBase mem(mem_);
+ mem_ = se::DeviceMemoryBase();
+ return mem;
+ }
+
+ // Frees the wrapped DeviceMemoryBase and deactivates this object.
+ // Precondition: `this` is active.
+ void Free();
+
+ private:
+ se::DeviceMemoryBase mem_;
+ int device_ordinal_;
+ DeviceMemoryAllocator* allocator_; // Null if this object is inactive.
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_
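
The ownership rules above (active vs. inactive, move, Forget(), Free()) are easier to see in a runnable analogue. The sketch below is not the real class: FakeAllocator stands in for xla::DeviceMemoryAllocator and raw host memory stands in for device memory; only the state transitions mirror OwningDeviceMemory.

#include <cstdlib>
#include <iostream>
#include <utility>

// Stand-in for DeviceMemoryAllocator: frees raw host memory and logs it.
struct FakeAllocator {
  void Deallocate(void* ptr) {
    std::cout << "freeing " << ptr << "\n";
    std::free(ptr);
  }
};

// Minimal analogue of OwningDeviceMemory: active iff allocator_ != nullptr.
class OwningMemory {
 public:
  OwningMemory() = default;
  OwningMemory(void* ptr, FakeAllocator* allocator)
      : ptr_(ptr), allocator_(allocator) {}
  OwningMemory(OwningMemory&& other)
      : ptr_(other.ptr_), allocator_(other.allocator_) {
    other.ptr_ = nullptr;
    other.allocator_ = nullptr;  // the moved-from object becomes inactive
  }
  OwningMemory& operator=(OwningMemory&& other) {
    if (allocator_ != nullptr) Free();
    ptr_ = other.ptr_;
    allocator_ = other.allocator_;
    other.ptr_ = nullptr;
    other.allocator_ = nullptr;
    return *this;
  }
  ~OwningMemory() {
    if (allocator_ != nullptr) Free();
  }
  // Releases ownership without freeing, like OwningDeviceMemory::Forget().
  void* Forget() {
    allocator_ = nullptr;
    void* p = ptr_;
    ptr_ = nullptr;
    return p;
  }
  // Frees the memory and deactivates this object; requires an active object.
  void Free() {
    allocator_->Deallocate(ptr_);
    allocator_ = nullptr;
    ptr_ = nullptr;
  }

 private:
  void* ptr_ = nullptr;
  FakeAllocator* allocator_ = nullptr;
};

int main() {
  FakeAllocator allocator;
  OwningMemory a(std::malloc(64), &allocator);
  OwningMemory b = std::move(a);  // `a` is now inactive; `b` owns the block
  void* raw = b.Forget();         // caller takes ownership back
  allocator.Deallocate(raw);      // and must free it manually
}
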
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 495f8801ba..047cadb3d9 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -64,7 +64,7 @@ namespace {
// Records the arguments used to invoke a computation in a SessionModule
// proto.
-tensorflow::Status RecordArguments(
+Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
se::StreamExecutor* executor, TransferManager* transfer_manager,
SessionModule* module) {
@@ -75,24 +75,22 @@ tensorflow::Status RecordArguments(
transfer_manager->TransferLiteralFromDevice(executor, *argument));
*module->add_arguments() = literal->ToProto();
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
// Records the result of a computation in a SessionModule proto.
-tensorflow::Status RecordResult(const ShapedBuffer& result,
- se::StreamExecutor* executor,
- TransferManager* transfer_manager,
- SessionModule* module) {
+Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor,
+ TransferManager* transfer_manager, SessionModule* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
transfer_manager->TransferLiteralFromDevice(executor, result));
*module->mutable_result() = literal->ToProto();
- return tensorflow::Status::OK();
+ return Status::OK();
}
// Records the arguments used to invoke a computation in an HloSnapshot proto.
-tensorflow::Status RecordArguments(
+Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
se::StreamExecutor* executor, TransferManager* transfer_manager,
HloSnapshot* module) {
@@ -103,20 +101,18 @@ tensorflow::Status RecordArguments(
transfer_manager->TransferLiteralFromDevice(executor, *argument));
*module->add_arguments() = literal->ToProto();
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
// Records the result of a computation in a HloSnapshot proto.
-tensorflow::Status RecordResult(const ShapedBuffer& result,
- se::StreamExecutor* executor,
- TransferManager* transfer_manager,
- HloSnapshot* module) {
+Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor,
+ TransferManager* transfer_manager, HloSnapshot* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
transfer_manager->TransferLiteralFromDevice(executor, result));
*module->mutable_result() = literal->ToProto();
- return tensorflow::Status::OK();
+ return Status::OK();
}
} // namespace
@@ -199,8 +195,8 @@ Service::Service(const ServiceOptions& options,
}
}
-tensorflow::Status Service::Computation(const ComputationRequest* arg,
- ComputationResponse* result) {
+Status Service::Computation(const ComputationRequest* arg,
+ ComputationResponse* result) {
if (arg->name().empty()) {
return InvalidArgument("computation request needs a name");
}
@@ -210,24 +206,23 @@ tensorflow::Status Service::Computation(const ComputationRequest* arg,
VLOG(1) << Printf("Created new computation %s on service %p, name %s",
result->computation().ShortDebugString().c_str(), this,
arg->name().c_str());
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::CreateChannelHandle(
- const CreateChannelHandleRequest* arg,
- CreateChannelHandleResponse* result) {
+Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg,
+ CreateChannelHandleResponse* result) {
*result->mutable_channel() = channel_tracker_.NewChannel();
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::Unregister(const UnregisterRequest* arg,
- UnregisterResponse* result) {
+Status Service::Unregister(const UnregisterRequest* arg,
+ UnregisterResponse* result) {
return allocation_tracker_.Unregister(arg->data());
}
// Deconstructs a previously-allocated global handle.
-tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
- DeconstructTupleResponse* result) {
+Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
+ DeconstructTupleResponse* result) {
TF_ASSIGN_OR_RETURN(
std::vector<GlobalDataHandle> elements,
allocation_tracker_.DeconstructTuple(arg->tuple_handle()));
@@ -235,11 +230,11 @@ tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
for (auto& element : elements) {
*result->add_element_handles() = element;
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::ValidateResultShapeWithLayout(
- const Shape& shape_with_layout, const Shape& result_shape) const {
+Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout,
+ const Shape& result_shape) const {
if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) {
return InvalidArgument(
"Shape used to set computation result layout %s is not compatible "
@@ -511,7 +506,7 @@ Status Service::ValidateEntryComputationLayout(HloModule* module) {
module->device_entry_computation_layout().result_shape(),
execute_backend_->transfer_manager()->HostShapeToDeviceShape(
module->host_entry_computation_layout().result_shape())));
- return tensorflow::Status::OK();
+ return Status::OK();
}
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
@@ -801,8 +796,8 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
result_tag);
}
-tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg,
- SetReturnValueResponse* results) {
+Status Service::SetReturnValue(const SetReturnValueRequest* arg,
+ SetReturnValueResponse* results) {
TF_ASSIGN_OR_RETURN(UserComputation * computation,
computation_tracker_.Resolve(arg->computation()));
return computation->SetReturnValue(arg->operand());
@@ -849,8 +844,8 @@ StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
return replicated_arguments;
}
-tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
- ExecuteParallelResponse* result) {
+Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
+ ExecuteParallelResponse* result) {
VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString();
std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
@@ -957,11 +952,11 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
}
VLOG(1) << "successfully completed 'execute-parallel' request";
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::ExecuteGraphParallel(
- const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) {
+Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
+ ExecuteParallelResponse* result) {
VLOG(1) << "running execute-graph-parallel request";
std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
@@ -1058,11 +1053,11 @@ tensorflow::Status Service::ExecuteGraphParallel(
}
VLOG(1) << "successfully completed 'execute-graph-parallel' request";
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
- GetDeviceHandlesResponse* result) {
+Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
+ GetDeviceHandlesResponse* result) {
const int64 available_device_count = execute_backend_->device_count();
const int64 replica_count = options_.number_of_replicas();
if (replica_count <= 0) {
@@ -1082,11 +1077,11 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
*result->add_device_handles() = device_handle;
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg,
- ExecuteResponse* result) {
+Status Service::ExecuteOneToN(const ExecuteRequest* arg,
+ ExecuteResponse* result) {
ExecuteParallelRequest parallel_arg;
*parallel_arg.add_requests() = *arg;
ExecuteParallelResponse parallel_result;
@@ -1094,8 +1089,8 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg,
return PickParallelResponse(parallel_result, result);
}
-tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg,
- ExecuteResponse* result) {
+Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg,
+ ExecuteResponse* result) {
ExecuteGraphParallelRequest parallel_arg;
*parallel_arg.add_requests() = *arg;
ExecuteParallelResponse parallel_result;
@@ -1103,7 +1098,7 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg,
return PickParallelResponse(parallel_result, result);
}
-tensorflow::Status Service::PickParallelResponse(
+Status Service::PickParallelResponse(
const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) {
// The "result device" selection is a bit hacky, but better than assuming it
// is device 0. We have b/76035356 for restructuring the client API to clean
@@ -1126,8 +1121,7 @@ tensorflow::Status Service::PickParallelResponse(
return Status::OK();
}
-tensorflow::Status Service::Execute(const ExecuteRequest* arg,
- ExecuteResponse* result) {
+Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
VLOG(1) << "running execute request: " << arg->ShortDebugString();
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
@@ -1198,7 +1192,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
}
VLOG(1) << "successfully completed 'execute' request";
- return tensorflow::Status::OK();
+ return Status::OK();
}
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
@@ -1243,8 +1237,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
return std::move(executable);
}
-tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
- ExecuteResponse* result) {
+Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
+ ExecuteResponse* result) {
VLOG(1) << "running execute-graph request";
if (!arg->has_computation()) {
@@ -1303,11 +1297,11 @@ tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
}
VLOG(1) << "successfully completed 'execute-graph' request";
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) {
+Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
+ ExecuteAsyncResponse* result) {
VLOG(1) << "running execute-async request: " << arg->ShortDebugString();
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
@@ -1383,11 +1377,11 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
streams.clear();
VLOG(1) << "successfully completed 'execute-async' request";
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
- WaitForExecutionResponse* result) {
+Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
+ WaitForExecutionResponse* result) {
TF_ASSIGN_OR_RETURN(const auto execution,
execution_tracker_.Resolve(arg->execution()));
@@ -1398,11 +1392,11 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution()));
VLOG(1) << "successfully completed 'wait-for-execution' request";
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg,
- TransferToClientResponse* result) {
+Status Service::TransferToClient(const TransferToClientRequest* arg,
+ TransferToClientResponse* result) {
TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
allocation_tracker_.ResolveForReplica(arg->data(), 0));
@@ -1432,7 +1426,7 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg,
*result->mutable_literal() =
result_literal->Relayout(*return_shape)->ToProto();
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
namespace {
@@ -1450,8 +1444,8 @@ std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
} // namespace
-tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
- TransferToServerResponse* result) {
+Status Service::TransferToServer(const TransferToServerRequest* arg,
+ TransferToServerResponse* result) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
Literal::CreateFromProto(arg->literal()));
const Shape& shape = literal->shape();
@@ -1484,11 +1478,11 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
StrCat("TransferToServer literal of shape ",
ShapeUtil::HumanString(shape))));
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
- TransferToInfeedResponse* result) {
+Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
+ TransferToInfeedResponse* result) {
const int64 replica_count = options_.number_of_replicas();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
@@ -1517,9 +1511,8 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
executor, *literal);
}
-tensorflow::Status Service::TransferFromOutfeed(
- const TransferFromOutfeedRequest* arg,
- TransferFromOutfeedResponse* result) {
+Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
+ TransferFromOutfeedResponse* result) {
const int64 replica_count = options_.number_of_replicas();
if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
return FailedPrecondition(
@@ -1545,16 +1538,16 @@ tensorflow::Status Service::TransferFromOutfeed(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
executor, arg->shape_with_layout(), &literal));
*result->mutable_literal() = literal.ToProto();
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
- ResetDeviceResponse* result) {
+Status Service::ResetDevice(const ResetDeviceRequest* arg,
+ ResetDeviceResponse* result) {
return execute_backend_->ResetDevices();
}
-tensorflow::Status Service::IsConstant(const IsConstantRequest* arg,
- IsConstantResponse* result) {
+Status Service::IsConstant(const IsConstantRequest* arg,
+ IsConstantResponse* result) {
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
computation_tracker_.Resolve(arg->computation()));
@@ -1570,11 +1563,11 @@ tensorflow::Status Service::IsConstant(const IsConstantRequest* arg,
user_computation->IsConstant(arg->operand(), arg->num_parameters()));
result->set_is_constant(is_constant);
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
- ComputeConstantResponse* result) {
+Status Service::ComputeConstant(const ComputeConstantRequest* arg,
+ ComputeConstantResponse* result) {
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
computation_tracker_.Resolve(arg->computation()));
@@ -1661,11 +1654,11 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
}
*result->mutable_literal() = result_literal->ToProto();
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::ComputeConstantGraph(
- const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) {
+Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
+ ComputeConstantResponse* result) {
if (!arg->has_computation()) {
return InvalidArgument("computations may not be empty");
}
@@ -1703,20 +1696,18 @@ tensorflow::Status Service::ComputeConstantGraph(
}
*result->mutable_literal() = result_literal->ToProto();
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::GetShape(const GetShapeRequest* arg,
- GetShapeResponse* result) {
+Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
allocation_tracker_.ResolveForReplica(arg->data(), 0));
*result->mutable_shape() = buffer->on_host_shape();
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::GetComputationShape(
- const GetComputationShapeRequest* arg,
- GetComputationShapeResponse* result) {
+Status Service::GetComputationShape(const GetComputationShapeRequest* arg,
+ GetComputationShapeResponse* result) {
TF_ASSIGN_OR_RETURN(UserComputation * computation,
computation_tracker_.Resolve(arg->computation()));
@@ -1726,21 +1717,21 @@ tensorflow::Status Service::GetComputationShape(
TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape(
versioned_handle.version));
*result->mutable_program_shape() = *program_shape;
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::GetLocalShape(const GetLocalShapeRequest* arg,
- GetLocalShapeResponse* result) {
+Status Service::GetLocalShape(const GetLocalShapeRequest* arg,
+ GetLocalShapeResponse* result) {
TF_ASSIGN_OR_RETURN(UserComputation * computation,
computation_tracker_.Resolve(arg->computation()));
TF_ASSIGN_OR_RETURN(*result->mutable_shape(),
computation->GetShape(arg->operand()));
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::GetComputationStats(
- const ComputationStatsRequest* arg, ComputationStatsResponse* result) {
+Status Service::GetComputationStats(const ComputationStatsRequest* arg,
+ ComputationStatsResponse* result) {
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
computation_tracker_.Resolve(arg->computation()));
@@ -1766,10 +1757,10 @@ tensorflow::Status Service::GetComputationStats(
stats.set_flop_count(analysis.flop_count());
stats.set_transcendental_count(analysis.transcendental_count());
*result->mutable_stats() = stats;
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::GetComputationGraphStats(
+Status Service::GetComputationGraphStats(
const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
if (!arg->has_computation()) {
return InvalidArgument("Computations may not be empty.");
@@ -1796,11 +1787,11 @@ tensorflow::Status Service::GetComputationGraphStats(
stats.set_flop_count(analysis.flop_count());
stats.set_transcendental_count(analysis.transcendental_count());
*result->mutable_stats() = stats;
- return tensorflow::Status::OK();
+ return Status::OK();
}
template <typename RequestT, typename ResponseT>
-tensorflow::Status Service::AddInstruction(
+Status Service::AddInstruction(
const RequestT* arg, ResponseT* result,
const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
adder) {
@@ -1808,10 +1799,10 @@ tensorflow::Status Service::AddInstruction(
computation_tracker_.Resolve(arg->computation()));
TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation));
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
+Status Service::Op(const OpRequest* arg, OpResponse* result) {
TF_ASSIGN_OR_RETURN(UserComputation * computation,
computation_tracker_.Resolve(arg->computation()));
StatusOr<ComputationDataHandle> handle_status;
@@ -2033,27 +2024,26 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
if (arg->has_sharding()) {
TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding()));
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::SnapshotComputation(
- const SnapshotComputationRequest* arg,
- SnapshotComputationResponse* result) {
+Status Service::SnapshotComputation(const SnapshotComputationRequest* arg,
+ SnapshotComputationResponse* result) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<SessionModule> module,
computation_tracker_.SnapshotComputation(arg->computation()));
result->set_allocated_module(module.release());
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status Service::LoadComputationSnapshot(
+Status Service::LoadComputationSnapshot(
const LoadComputationSnapshotRequest* arg,
LoadComputationSnapshotResponse* result) {
TF_ASSIGN_OR_RETURN(*result->mutable_computation(),
computation_tracker_.LoadSessionModule(arg->module()));
- return tensorflow::Status::OK();
+ return Status::OK();
}
DeviceHandle Service::SingleComputationDeviceHandle() const {
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index f84fe407e0..81fbd41957 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -85,55 +85,52 @@ class Service : public ServiceInterface {
// Creates a new computation with the given name.
// A unique ComputationHandle is returned.
- tensorflow::Status Computation(const ComputationRequest* arg,
- ComputationResponse* result) override;
+ Status Computation(const ComputationRequest* arg,
+ ComputationResponse* result) override;
// Unregisters a previously-allocated global handle.
//
// If the handle given is not currently allocated, a NOT_FOUND status is
// returned.
- tensorflow::Status Unregister(const UnregisterRequest* arg,
- UnregisterResponse* result) override;
+ Status Unregister(const UnregisterRequest* arg,
+ UnregisterResponse* result) override;
// Deconstructs a tuple. Returns a newly created GlobalDataHandle for each
// element in the tuple.
- tensorflow::Status DeconstructTuple(
- const DeconstructTupleRequest* arg,
- DeconstructTupleResponse* result) override;
+ Status DeconstructTuple(const DeconstructTupleRequest* arg,
+ DeconstructTupleResponse* result) override;
// Modifies the provided computation so that subsequent executions
// will compute the provided ComputationDataHandle, rather than the
// last expression enqueued on that Computation.
- tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg,
- SetReturnValueResponse* results) override;
+ Status SetReturnValue(const SetReturnValueRequest* arg,
+ SetReturnValueResponse* results) override;
// Executes a computation with the provided global data passed as
// immutable arguments. Returns global data output and execution timing.
- tensorflow::Status Execute(const ExecuteRequest* arg,
- ExecuteResponse* result) override;
+ Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override;
// Executes a computation with the provided global data passed as
// immutable arguments. The request contains the whole computation graph.
// Returns global data output and execution timing.
//
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
- tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg,
- ExecuteResponse* result) override;
+ Status ExecuteGraph(const ExecuteGraphRequest* arg,
+ ExecuteResponse* result) override;
// Executes one or more computations in parallel with the provided global data
// passed as immutable arguments. Returns global data output for each
// computation.
- tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
- ExecuteParallelResponse* result) override;
+ Status ExecuteParallel(const ExecuteParallelRequest* arg,
+ ExecuteParallelResponse* result) override;
// Executes one or more computations in parallel with the provided global data
// passed as immutable arguments. Returns global data output for each
// computation.
//
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
- tensorflow::Status ExecuteGraphParallel(
- const ExecuteGraphParallelRequest* arg,
- ExecuteParallelResponse* result) override;
+ Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
+ ExecuteParallelResponse* result) override;
// Requests one or more device handles from the target.
//
@@ -143,9 +140,8 @@ class Service : public ServiceInterface {
// the first set of replicas, and the next R devices to the second set of
// replicas, etc. Each returned device handle represents the device with the
// replica id 0.
- tensorflow::Status GetDeviceHandles(
- const GetDeviceHandlesRequest* arg,
- GetDeviceHandlesResponse* result) override;
+ Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
+ GetDeviceHandlesResponse* result) override;
// Asynchronously executes a computation with provided arguments. Invokes
// the provided computation with the provided global data passed as
@@ -154,38 +150,33 @@ class Service : public ServiceInterface {
// (Note: The corresponding function in xla::Client was removed as part of
// b/64116060, in an attempt to simplify our API. We're keeping this around
// for now in case we want to expose this to clients in a different way.)
- tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) override;
+ Status ExecuteAsync(const ExecuteAsyncRequest* arg,
+ ExecuteAsyncResponse* result) override;
// Waits until the specified execution is complete and returns the result.
  // Calling this API multiple times with the same execution handle fails with
  // an error, since the execution handle is destroyed after the first call.
- tensorflow::Status WaitForExecution(
- const WaitForExecutionRequest* arg,
- WaitForExecutionResponse* result) override;
+ Status WaitForExecution(const WaitForExecutionRequest* arg,
+ WaitForExecutionResponse* result) override;
// Requests that global data be transferred to the client in literal form.
- tensorflow::Status TransferToClient(
- const TransferToClientRequest* arg,
- TransferToClientResponse* result) override;
+ Status TransferToClient(const TransferToClientRequest* arg,
+ TransferToClientResponse* result) override;
// Transfers data from a literal provided by the client, into device memory.
- tensorflow::Status TransferToServer(
- const TransferToServerRequest* arg,
- TransferToServerResponse* result) override;
+ Status TransferToServer(const TransferToServerRequest* arg,
+ TransferToServerResponse* result) override;
// Transfers data from a literal provided by the client, into the Infeed
// buffer of the device.
- tensorflow::Status TransferToInfeed(
- const TransferToInfeedRequest* arg,
- TransferToInfeedResponse* result) override;
+ Status TransferToInfeed(const TransferToInfeedRequest* arg,
+ TransferToInfeedResponse* result) override;
  // Transfers data from the Outfeed of the device to the literal provided by the
// client.
- tensorflow::Status TransferFromOutfeed(
- const TransferFromOutfeedRequest* arg,
- TransferFromOutfeedResponse* result) override;
+ Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
+ TransferFromOutfeedResponse* result) override;
// Resets devices, clearing all existing state on all the devices associated
// with this service (including memory allocated on the devices).
@@ -196,71 +187,65 @@ class Service : public ServiceInterface {
  // ResetDevice should be called before an Execution that expects the device to
// be in the reset state. For example, if the prior Execution modifies device
// state (e.g., architectural state) that the next Execution depends on.
- tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
- ResetDeviceResponse* result) override;
+ Status ResetDevice(const ResetDeviceRequest* arg,
+ ResetDeviceResponse* result) override;
// Tests if an expression is a compile-time constant.
- tensorflow::Status IsConstant(const IsConstantRequest* arg,
- IsConstantResponse* result) override;
+ Status IsConstant(const IsConstantRequest* arg,
+ IsConstantResponse* result) override;
// Computes the value of a constant expression.
- tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg,
- ComputeConstantResponse* result) override;
- tensorflow::Status ComputeConstantGraph(
- const ComputeConstantGraphRequest* arg,
- ComputeConstantResponse* result) override;
+ Status ComputeConstant(const ComputeConstantRequest* arg,
+ ComputeConstantResponse* result) override;
+ Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
+ ComputeConstantResponse* result) override;
// Returns the shape (with layout) of an array associated with a given data
// handle.
- tensorflow::Status GetShape(const GetShapeRequest* arg,
- GetShapeResponse* result) override;
+ Status GetShape(const GetShapeRequest* arg,
+ GetShapeResponse* result) override;
// Returns the program shape of the computation associated with the given
// handle.
- tensorflow::Status GetComputationShape(
- const GetComputationShapeRequest* arg,
- GetComputationShapeResponse* result) override;
+ Status GetComputationShape(const GetComputationShapeRequest* arg,
+ GetComputationShapeResponse* result) override;
/////
// Computation-oriented methods.
// Enqueues an Op on the computation.
- tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override;
+ Status Op(const OpRequest* arg, OpResponse* result) override;
// Retrieves the inferred shape for a value within a computation.
- tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg,
- GetLocalShapeResponse* result) override;
+ Status GetLocalShape(const GetLocalShapeRequest* arg,
+ GetLocalShapeResponse* result) override;
// Retrieves the statistics of a computation.
- tensorflow::Status GetComputationStats(
- const ComputationStatsRequest* arg,
- ComputationStatsResponse* result) override;
+ Status GetComputationStats(const ComputationStatsRequest* arg,
+ ComputationStatsResponse* result) override;
// Retrieves the statistics of a computation.
//
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
- tensorflow::Status GetComputationGraphStats(
- const ComputationGraphStatsRequest* arg,
- ComputationStatsResponse* result) override;
+ Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg,
+ ComputationStatsResponse* result) override;
// Snapshots the current state of a computation handle into a serializable
// protocol buffer form, so it can be loaded via
// LoadComputationSnapshot.
- tensorflow::Status SnapshotComputation(
- const SnapshotComputationRequest* arg,
- SnapshotComputationResponse* result) override;
+ Status SnapshotComputation(const SnapshotComputationRequest* arg,
+ SnapshotComputationResponse* result) override;
// Loads a computation from a serialized protocol buffer created via
// SnapshotComputation.
- tensorflow::Status LoadComputationSnapshot(
+ Status LoadComputationSnapshot(
const LoadComputationSnapshotRequest* arg,
LoadComputationSnapshotResponse* result) override;
// Creates a unique channel handle that can be used for Send/Recv
// instructions.
- tensorflow::Status CreateChannelHandle(
- const CreateChannelHandleRequest* arg,
- CreateChannelHandleResponse* result) override;
+ Status CreateChannelHandle(const CreateChannelHandleRequest* arg,
+ CreateChannelHandleResponse* result) override;
// Returns the ComputationTracker of the current service instance.
// Only used in unit tests to access user computations from client.
@@ -389,7 +374,7 @@ class Service : public ServiceInterface {
// Convenience function for adding a function to a user computation.
template <typename RequestT, typename ResponseT>
- tensorflow::Status AddInstruction(
+ Status AddInstruction(
const RequestT* arg, ResponseT* result,
const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
adder);
@@ -397,16 +382,14 @@ class Service : public ServiceInterface {
// Executes a single computation which has more than one target device.
  // All but one of the N devices are expected to return an empty tuple; the
  // remaining device's output is the result of this computation.
- tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg,
- ExecuteResponse* result);
- tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg,
- ExecuteResponse* result);
+ Status ExecuteOneToN(const ExecuteRequest* arg, ExecuteResponse* result);
+ Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result);
// Convenience function which checks whether the given shape_with_layout
// (presumably passed by the client to set the result layout) is valid for the
// given computation result shape.
- tensorflow::Status ValidateResultShapeWithLayout(
- const Shape& shape_with_layout, const Shape& result_shape) const;
+ Status ValidateResultShapeWithLayout(const Shape& shape_with_layout,
+ const Shape& result_shape) const;
// Returns the stream executors assigned to the replicas represented by the
// given device handle. Each device_handle is a virtual replicated device that
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index c493547d9e..3500978bdd 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -58,6 +58,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
return UNOP_COS;
case HloOpcode::kExp:
return UNOP_EXP;
+ case HloOpcode::kExpm1:
+ return UNOP_EXPM1;
case HloOpcode::kFloor:
return UNOP_FLOOR;
case HloOpcode::kImag:
@@ -66,6 +68,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
return UNOP_IS_FINITE;
case HloOpcode::kLog:
return UNOP_LOG;
+ case HloOpcode::kLog1p:
+ return UNOP_LOG1P;
case HloOpcode::kNot:
return UNOP_NOT;
case HloOpcode::kNegate:
@@ -168,8 +172,8 @@ bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}
-tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape,
- tensorflow::StringPiece op_type) {
+Status ExpectNotTupleOrOpaque(const Shape& shape,
+ tensorflow::StringPiece op_type) {
if (ShapeUtil::IsTuple(shape)) {
return InvalidArgument("Expected non-tuple argument for %s, but got %s.",
std::string(op_type).c_str(),
@@ -179,13 +183,13 @@ tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape,
std::string(op_type).c_str(),
ShapeUtil::HumanString(shape).c_str());
} else {
- return tensorflow::Status::OK();
+ return Status::OK();
}
}
-tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
- const Shape& init_value_shape,
- const PrimitiveType& input_element_type) {
+Status VerifyReducerShape(const ProgramShape& reducer_shape,
+ const Shape& init_value_shape,
+ const PrimitiveType& input_element_type) {
if (reducer_shape.parameters_size() != 2) {
return InvalidArgument(
"Reduction function must take 2 parameters, but "
@@ -245,7 +249,7 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
ShapeUtil::HumanString(accumulator_shape).c_str());
}
- return tensorflow::Status::OK();
+ return Status::OK();
}
StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
@@ -337,7 +341,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
case UNOP_COS:
case UNOP_SIN:
case UNOP_EXP:
+ case UNOP_EXPM1:
case UNOP_LOG:
+ case UNOP_LOG1P:
case UNOP_TANH:
if (!ShapeUtil::ElementIsFloating(arg) &&
!ShapeUtil::ElementIsComplex(arg)) {
@@ -1212,11 +1218,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
scale_shape, "scale input of batch norm training"));
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
- tensorflow::Status::OK());
+ Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
- tensorflow::Status::OK());
+ Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
- tensorflow::Status::OK());
+ Status::OK());
if (feature_index >= ShapeUtil::Rank(operand_shape)) {
return InvalidArgument(
@@ -1318,15 +1324,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
scale_shape, "scale input of batch norm inference"));
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
- tensorflow::Status::OK());
+ Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) ==
- tensorflow::Status::OK());
+ Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) ==
- tensorflow::Status::OK());
+ Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) ==
- tensorflow::Status::OK());
+ Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) ==
- tensorflow::Status::OK());
+ Status::OK());
if (feature_index >= ShapeUtil::Rank(operand_shape)) {
return InvalidArgument(
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index fb3b5f06da..6bacb37206 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
-#include <set>
#include <string>
#include <utility>
@@ -25,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
@@ -138,14 +138,12 @@ ScopedShapedBuffer::~ScopedShapedBuffer() {
// Deallocate all non-null buffers. A buffer may appear in more than one spot
// in the shape (eg, a tuple with a repeated element) so keep track of what
// has been deallocated.
- std::set<void*> deallocated_opaques;
+ tensorflow::gtl::FlatSet<void*> deallocated_ptrs;
for (auto& pair : buffers_) {
se::DeviceMemoryBase& memory_base = pair.second;
if (!memory_base.is_null() &&
- deallocated_opaques.count(memory_base.opaque()) == 0) {
- deallocated_opaques.insert(memory_base.opaque());
- TF_CHECK_OK(
- this->allocator_->Deallocate(this->device_ordinal(), &memory_base));
+ deallocated_ptrs.insert(memory_base.opaque()).second) {
+ TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base));
}
}
}
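The destructor change above replaces a std::set plus a separate count()/insert() pair with a single FlatSet::insert whose .second result reports whether the pointer was newly inserted. A minimal standalone sketch of that idiom follows; std::unordered_set and std::free stand in for tensorflow::gtl::FlatSet and allocator_->Deallocate(), and nothing here is XLA API.

// Standalone sketch of the deduplicated-free loop; not XLA code.
#include <cstdlib>
#include <unordered_set>
#include <vector>

void FreeEachPointerOnce(const std::vector<void*>& buffers) {
  std::unordered_set<void*> deallocated_ptrs;
  for (void* ptr : buffers) {
    // insert().second is true only the first time a pointer is seen, so a
    // buffer aliased at several tuple indices is freed exactly once.
    if (ptr != nullptr && deallocated_ptrs.insert(ptr).second) {
      std::free(ptr);
    }
  }
}

int main() {
  void* shared = std::malloc(16);
  FreeEachPointerOnce({shared, shared, nullptr, std::malloc(8)});
  return 0;
}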
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h
index e10fca9e94..25b709523b 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.h
+++ b/tensorflow/compiler/xla/service/shaped_buffer.h
@@ -148,11 +148,25 @@ class ScopedShapedBuffer : public ShapedBuffer {
// ScopedShapedBuffer.
DeviceMemoryAllocator* memory_allocator() const { return allocator_; }
- // Releases all device memory owned by this ScopedShapedBuffer and returns the
- // device memory pointers in the form of a ShapedBuffer. The returned
- // ShapedBuffer takes over the memory from the ScopedShapedBuffer. The
- // resulting ScopedShapedBuffer can only be destroyed.
- ShapedBuffer release();
+ // Sets the device memory buffer at the given index.
+ //
+ // If the given buffer's device memory is non-null, its device_ordinal and
+ // allocator must match those in `this`.
+ void set_buffer(OwningDeviceMemory buffer, const ShapeIndex& index) {
+ if (!buffer.is_null()) {
+ CHECK_EQ(buffer.device_ordinal(), device_ordinal());
+ CHECK_EQ(buffer.allocator(), allocator_);
+ *buffers_.mutable_element(index) = buffer.Forget();
+ } else {
+ *buffers_.mutable_element(index) = se::DeviceMemoryBase();
+ }
+ }
+
+ // Like unique_ptr::release(), creates and returns a regular ShapedBuffer from
+ // this ScopedShapedBuffer, without freeing any of the associated memory.
+ //
+ // It's the caller's job to ensure that the memory contained therein is freed.
+ TF_MUST_USE_RESULT ShapedBuffer release();
protected:
DeviceMemoryAllocator* allocator_;
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index 8b71a41509..3e7338fd13 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -196,9 +196,11 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
const ShapeIndex& index = pair.first;
se::DeviceMemoryBase& memory_base = pair.second;
const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index);
- TF_ASSIGN_OR_RETURN(memory_base,
+ TF_ASSIGN_OR_RETURN(auto memory,
allocator->Allocate(shaped_buffer.device_ordinal(),
GetByteSizeRequirement(subshape)));
+ // Move the allocated buffer into the ScopedShapedBuffer, which owns it.
+ memory_base = memory.Forget();
}
return std::move(shaped_buffer);
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index d82b4f0f81..55c544fcd2 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -81,7 +81,7 @@ class TransferManager {
// Transfers the given literal into the Infeed interface of the device,
// using the given executor.
virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) = 0;
+ const LiteralSlice& literal) = 0;
// Transfers the given literal from the Outfeed interface of the device,
// using the given executor.
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index f7a5512fec..ba16dc640e 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -215,7 +215,7 @@ StatusOr<bool> TransposeFolding::Run(HloModule* module) {
std::make_pair(instruction, operand_indices));
}
}
- return tensorflow::Status::OK();
+ return Status::OK();
};
for (auto* comp : module->MakeNonfusionComputations()) {
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 0f16a592b6..9e62d0acfb 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -55,6 +55,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
return HloOpcode::kCos;
case UNOP_EXP:
return HloOpcode::kExp;
+ case UNOP_EXPM1:
+ return HloOpcode::kExpm1;
case UNOP_FLOOR:
return HloOpcode::kFloor;
case UNOP_IMAG:
@@ -63,6 +65,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
return HloOpcode::kIsFinite;
case UNOP_LOG:
return HloOpcode::kLog;
+ case UNOP_LOG1P:
+ return HloOpcode::kLog1p;
case UNOP_NOT:
return HloOpcode::kNot;
case UNOP_NEGATE:
diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h
index 4f64fe8f83..141347a792 100644
--- a/tensorflow/compiler/xla/service_interface.h
+++ b/tensorflow/compiler/xla/service_interface.h
@@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_
+#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/status.h"
namespace xla {
@@ -32,99 +32,93 @@ class ServiceInterface {
virtual ~ServiceInterface() = default;
// TODO(b/31824348): Convert to use StatusOr.
- virtual tensorflow::Status TransferToClient(
- const TransferToClientRequest* arg, TransferToClientResponse* result) = 0;
+ virtual Status TransferToClient(const TransferToClientRequest* arg,
+ TransferToClientResponse* result) = 0;
- virtual tensorflow::Status TransferToServer(
- const TransferToServerRequest* arg, TransferToServerResponse* result) = 0;
+ virtual Status TransferToServer(const TransferToServerRequest* arg,
+ TransferToServerResponse* result) = 0;
- virtual tensorflow::Status TransferToInfeed(
- const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0;
+ virtual Status TransferToInfeed(const TransferToInfeedRequest* arg,
+ TransferToInfeedResponse* result) = 0;
- virtual tensorflow::Status TransferFromOutfeed(
- const TransferFromOutfeedRequest* arg,
- TransferFromOutfeedResponse* result) = 0;
+ virtual Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
+ TransferFromOutfeedResponse* result) = 0;
- virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
- ResetDeviceResponse* result) = 0;
+ virtual Status ResetDevice(const ResetDeviceRequest* arg,
+ ResetDeviceResponse* result) = 0;
- virtual tensorflow::Status LoadComputationSnapshot(
+ virtual Status LoadComputationSnapshot(
const LoadComputationSnapshotRequest* request,
LoadComputationSnapshotResponse* result) = 0;
- virtual tensorflow::Status Execute(const ExecuteRequest* arg,
- ExecuteResponse* result) = 0;
+ virtual Status Execute(const ExecuteRequest* arg,
+ ExecuteResponse* result) = 0;
- virtual tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg,
- ExecuteResponse* result) = 0;
+ virtual Status ExecuteGraph(const ExecuteGraphRequest* arg,
+ ExecuteResponse* result) = 0;
- virtual tensorflow::Status ExecuteParallel(
- const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0;
+ virtual Status ExecuteParallel(const ExecuteParallelRequest* arg,
+ ExecuteParallelResponse* result) = 0;
- virtual tensorflow::Status ExecuteGraphParallel(
- const ExecuteGraphParallelRequest* arg,
- ExecuteParallelResponse* result) = 0;
+ virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
+ ExecuteParallelResponse* result) = 0;
- virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) = 0;
+ virtual Status ExecuteAsync(const ExecuteAsyncRequest* arg,
+ ExecuteAsyncResponse* result) = 0;
- virtual tensorflow::Status WaitForExecution(
- const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0;
+ virtual Status WaitForExecution(const WaitForExecutionRequest* arg,
+ WaitForExecutionResponse* result) = 0;
- virtual tensorflow::Status DeconstructTuple(
- const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0;
+ virtual Status DeconstructTuple(const DeconstructTupleRequest* arg,
+ DeconstructTupleResponse* result) = 0;
- virtual tensorflow::Status GetComputationStats(
- const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0;
+ virtual Status GetComputationStats(const ComputationStatsRequest* arg,
+ ComputationStatsResponse* result) = 0;
- virtual tensorflow::Status GetComputationGraphStats(
+ virtual Status GetComputationGraphStats(
const ComputationGraphStatsRequest* arg,
ComputationStatsResponse* result) = 0;
- virtual tensorflow::Status GetComputationShape(
- const GetComputationShapeRequest* arg,
- GetComputationShapeResponse* result) = 0;
+ virtual Status GetComputationShape(const GetComputationShapeRequest* arg,
+ GetComputationShapeResponse* result) = 0;
- virtual tensorflow::Status GetShape(const GetShapeRequest* arg,
- GetShapeResponse* result) = 0;
+ virtual Status GetShape(const GetShapeRequest* arg,
+ GetShapeResponse* result) = 0;
- virtual tensorflow::Status CreateChannelHandle(
- const CreateChannelHandleRequest* arg,
- CreateChannelHandleResponse* result) = 0;
+ virtual Status CreateChannelHandle(const CreateChannelHandleRequest* arg,
+ CreateChannelHandleResponse* result) = 0;
- virtual tensorflow::Status GetDeviceHandles(
- const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0;
+ virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
+ GetDeviceHandlesResponse* result) = 0;
// Methods used by ComputationBuilder.
- virtual tensorflow::Status Computation(const ComputationRequest* arg,
- ComputationResponse* result) = 0;
+ virtual Status Computation(const ComputationRequest* arg,
+ ComputationResponse* result) = 0;
- virtual tensorflow::Status Op(const OpRequest* arg, OpResponse* result) = 0;
+ virtual Status Op(const OpRequest* arg, OpResponse* result) = 0;
- virtual tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg,
- GetLocalShapeResponse* result) = 0;
+ virtual Status GetLocalShape(const GetLocalShapeRequest* arg,
+ GetLocalShapeResponse* result) = 0;
- virtual tensorflow::Status SetReturnValue(
- const SetReturnValueRequest* arg, SetReturnValueResponse* results) = 0;
+ virtual Status SetReturnValue(const SetReturnValueRequest* arg,
+ SetReturnValueResponse* results) = 0;
- virtual tensorflow::Status IsConstant(const IsConstantRequest* arg,
- IsConstantResponse* result) = 0;
+ virtual Status IsConstant(const IsConstantRequest* arg,
+ IsConstantResponse* result) = 0;
- virtual tensorflow::Status ComputeConstant(
- const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0;
+ virtual Status ComputeConstant(const ComputeConstantRequest* arg,
+ ComputeConstantResponse* result) = 0;
- virtual tensorflow::Status ComputeConstantGraph(
- const ComputeConstantGraphRequest* arg,
- ComputeConstantResponse* result) = 0;
+ virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
+ ComputeConstantResponse* result) = 0;
// Methods used by Computation.
- virtual tensorflow::Status SnapshotComputation(
- const SnapshotComputationRequest* ag,
- SnapshotComputationResponse* result) = 0;
+ virtual Status SnapshotComputation(const SnapshotComputationRequest* ag,
+ SnapshotComputationResponse* result) = 0;
// Methods used by GlobalData.
- virtual tensorflow::Status Unregister(const UnregisterRequest* arg,
- UnregisterResponse* result) = 0;
+ virtual Status Unregister(const UnregisterRequest* arg,
+ UnregisterResponse* result) = 0;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc
index 789eba5780..7ee366b27a 100644
--- a/tensorflow/compiler/xla/shape_layout.cc
+++ b/tensorflow/compiler/xla/shape_layout.cc
@@ -22,24 +22,24 @@ limitations under the License.
namespace xla {
-tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) {
+Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) {
if (!ShapeUtil::Compatible(other_shape, shape_)) {
return InvalidArgument("Shape %s is not compatible with shape %s",
ShapeUtil::HumanString(other_shape).c_str(),
ShapeUtil::HumanString(shape()).c_str());
}
shape_ = other_shape;
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const {
+Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const {
if (!ShapeUtil::Compatible(*to_shape, shape_)) {
return InvalidArgument("Shape %s is not compatible with shape %s",
ShapeUtil::HumanString(*to_shape).c_str(),
ShapeUtil::HumanString(shape()).c_str());
}
*to_shape = shape_;
- return tensorflow::Status::OK();
+ return Status::OK();
}
void ShapeLayout::SetToDefaultLayout() {
diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h
index a1dce758cd..36806da599 100644
--- a/tensorflow/compiler/xla/shape_layout.h
+++ b/tensorflow/compiler/xla/shape_layout.h
@@ -40,7 +40,7 @@ class ShapeLayout {
// Assigns the layouts in this ShapeLayout to the Layout fields of the given
// shape. 'to_shape' and the shape of the ShapeLayout object must be
// compatible.
- tensorflow::Status AssignLayoutToShape(Shape* to_shape) const;
+ Status AssignLayoutToShape(Shape* to_shape) const;
// Returns true if the Layouts in this ShapeLayout match the layouts in the
// given shape. Returns false otherwise. If the given shape is not compatible
@@ -49,7 +49,7 @@ class ShapeLayout {
// Copies the layout from the given shape into this ShapeLayout. 'other_shape'
// must be compatible with the ShapeLayout's shape.
- tensorflow::Status CopyLayoutFromShape(const Shape& other_shape);
+ Status CopyLayoutFromShape(const Shape& other_shape);
// Clears (Layout::Clear) all the Layouts stored in this object.
void Clear();
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index cb8bf5a2b9..82c75f85d8 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -231,7 +231,7 @@ class ShapeUtil {
}
// Returns the higher-precision element type if a and b are both floating
- // point types; otherwise, checks that that they have the same element type
+ // point types; otherwise, checks that they have the same element type
// and returns it.
static PrimitiveType HigherPrecisionElementType(const Shape& a,
const Shape& b) {
diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h
index 4eb3bf3766..69abb51852 100644
--- a/tensorflow/compiler/xla/status.h
+++ b/tensorflow/compiler/xla/status.h
@@ -21,7 +21,7 @@ limitations under the License.
namespace xla {
-using tensorflow::Status;
+using tensorflow::Status; // TENSORFLOW_STATUS_OK
} // namespace xla
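The one-line change above only annotates the existing alias, but it is this `using tensorflow::Status;` inside namespace xla that lets every `tensorflow::Status` to `Status` and `tensorflow::Status::OK()` to `Status::OK()` rewrite elsewhere in the patch keep naming the same type. A standalone sketch of the mechanism, with made-up namespace names (tensorflow_like/xla_like) rather than the real ones:

#include <cstdio>

namespace tensorflow_like {
class Status {
 public:
  static Status OK() { return Status(); }
  bool ok() const { return true; }
};
}  // namespace tensorflow_like

namespace xla_like {
using tensorflow_like::Status;  // mirrors `using tensorflow::Status;` above

// The unqualified name now resolves to tensorflow_like::Status.
Status DoWork() { return Status::OK(); }
}  // namespace xla_like

int main() {
  std::printf("%s\n", xla_like::DoWork().ok() ? "ok" : "not ok");
  return 0;
}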
diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc
index 7d76370e85..377a618ffb 100644
--- a/tensorflow/compiler/xla/statusor_test.cc
+++ b/tensorflow/compiler/xla/statusor_test.cc
@@ -413,7 +413,7 @@ TEST(StatusOr, TestPointerValueConst) {
EXPECT_EQ(&kI, thing.ValueOrDie());
}
-// NOTE(tucker): tensorflow::StatusOr does not support this kind
+// NOTE(tucker): StatusOr does not support this kind
// of resize op.
// TEST(StatusOr, StatusOrVectorOfUniquePointerCanResize) {
// using EvilType = std::vector<std::unique_ptr<int>>;
diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h
index 17bae2e4f6..8918350135 100644
--- a/tensorflow/compiler/xla/test_helpers.h
+++ b/tensorflow/compiler/xla/test_helpers.h
@@ -40,13 +40,10 @@ class Literal;
namespace testing {
namespace internal_status {
-inline const ::tensorflow::Status& GetStatus(
- const ::tensorflow::Status& status) {
- return status;
-}
+inline const Status& GetStatus(const Status& status) { return status; }
template <typename T>
-inline const ::tensorflow::Status& GetStatus(const StatusOr<T>& status) {
+inline const Status& GetStatus(const StatusOr<T>& status) {
return status.status();
}
} // namespace internal_status
@@ -57,21 +54,17 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr<T>& status) {
// The following macros are similar to macros in gmock, but deliberately named
// differently in order to avoid conflicts in files which include both.
-// Macros for testing the results of functions that return tensorflow::Status or
+// Macros for testing the results of functions that return Status or
// StatusOr<T> (for any type T).
-#define EXPECT_IS_OK(expression) \
- EXPECT_EQ(tensorflow::Status::OK(), \
- xla::testing::internal_status::GetStatus(expression))
-#define EXPECT_IS_NOT_OK(expression) \
- EXPECT_NE(tensorflow::Status::OK(), \
- xla::testing::internal_status::GetStatus(expression))
+#define EXPECT_IS_OK(expression) \
+ EXPECT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression))
+#define EXPECT_IS_NOT_OK(expression) \
+ EXPECT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression))
#undef ASSERT_IS_OK
-#define ASSERT_IS_OK(expression) \
- ASSERT_EQ(tensorflow::Status::OK(), \
- xla::testing::internal_status::GetStatus(expression))
+#define ASSERT_IS_OK(expression) \
+ ASSERT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression))
#undef ASSERT_IS_NOT_OK
-#define ASSERT_IS_NOT_OK(expression) \
- ASSERT_NE(tensorflow::Status::OK(), \
- xla::testing::internal_status::GetStatus(expression))
+#define ASSERT_IS_NOT_OK(expression) \
+ ASSERT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression))
#endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_
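A hedged usage sketch of the shortened macros above, as they might appear in an XLA gtest test: DoSomething() is invented for illustration, and the include set is an assumption, while EXPECT_IS_OK, ASSERT_IS_OK, InvalidArgument, and StatusOr all appear in files this patch touches.

// Hypothetical usage sketch; not part of the patch.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

StatusOr<int> DoSomething() { return 42; }  // hypothetical function under test

TEST(StatusMacroExample, OkAndNotOk) {
  ASSERT_IS_OK(Status::OK());               // expands to ASSERT_EQ(Status::OK(), ...)
  EXPECT_IS_NOT_OK(InvalidArgument("bad argument"));
  EXPECT_IS_OK(DoSomething());              // StatusOr<T> works via GetStatus()
}

}  // namespace
}  // namespace xla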
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index b982cf0dbc..7a528a2247 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -87,6 +87,8 @@ cc_library(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:error_spec",
+ "//tensorflow/compiler/xla:literal_comparison",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -152,7 +154,6 @@ tf_cc_binary(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
@@ -188,8 +189,6 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -288,8 +287,6 @@ xla_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -313,7 +310,6 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -335,7 +331,6 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -378,7 +373,6 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -398,7 +392,6 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -422,8 +415,6 @@ xla_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
@@ -450,8 +441,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -472,7 +461,6 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -491,7 +479,6 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -528,7 +515,6 @@ xla_test(
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
@@ -552,7 +538,6 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -572,8 +557,6 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -598,8 +581,6 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -626,7 +607,6 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -697,7 +677,6 @@ xla_test(
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -741,7 +720,6 @@ xla_test(
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -766,7 +744,6 @@ xla_test(
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -790,7 +767,6 @@ xla_test(
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -843,7 +819,6 @@ xla_test(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -868,7 +843,6 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -930,8 +904,6 @@ xla_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
@@ -960,8 +932,6 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
@@ -1002,7 +972,6 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1055,8 +1024,6 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1078,7 +1045,6 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -1108,8 +1074,6 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
@@ -1240,8 +1204,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
@@ -1281,7 +1243,6 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:reference_util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1304,7 +1265,6 @@ xla_test(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1344,7 +1304,6 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1362,7 +1321,6 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1388,8 +1346,6 @@ xla_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1411,7 +1367,6 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1483,8 +1438,6 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
@@ -1532,7 +1485,6 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1574,8 +1526,6 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -1596,7 +1546,6 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1620,8 +1569,6 @@ xla_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1642,7 +1589,6 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -1661,7 +1607,6 @@ xla_test(
srcs = ["execution_profile_test.cc"],
deps = [
":client_library_test_base",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1676,7 +1621,6 @@ xla_test(
args = ["--xla_hlo_profile"],
deps = [
":client_library_test_base",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1782,8 +1726,6 @@ xla_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1811,8 +1753,6 @@ xla_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1850,8 +1790,6 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1880,8 +1818,6 @@ xla_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
@@ -1949,8 +1885,6 @@ xla_test(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -2051,7 +1985,6 @@ xla_test(
":local_client_test_base",
":test_utils",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
index 6ebbf71918..51b9f0d3e3 100644
--- a/tensorflow/compiler/xla/tests/broadcast_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(42.0), *result,
- error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0<float>(42.0), *result,
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
@@ -62,9 +62,9 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(
+ EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
- error_spec_);
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
@@ -85,13 +85,13 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(
+ EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
- LiteralView::Create(*result, {0}), error_spec_);
+ LiteralSlice(*result, {0}), error_spec_));
- LiteralTestUtil::ExpectNear(
+ EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
- LiteralView::Create(*result, {1}), error_spec_);
+ LiteralSlice(*result, {1}), error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
@@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(
- *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
- error_spec_);
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ *result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
@@ -125,9 +125,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(
- *Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
- error_spec_);
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}),
+ *result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
@@ -142,10 +142,10 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(
+ EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
{{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
- *result, error_spec_);
+ *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
@@ -166,8 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
Array2D<float> pz({{1, 2}, {1, 2}});
expected.FillWithPZ(pz);
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@@ -196,8 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
}
expected.FillWithYX(yx);
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@@ -218,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result,
- error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array),
+ *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
@@ -238,8 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
Array4D<float> expected(64, 64, 3, 3);
expected.Fill(1.0f);
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
@@ -260,8 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
Array4D<float> expected(3, 3, 2, 2);
expected.FillWithYX(to_broadcast);
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@@ -291,8 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
}
} // namespace
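The broadcast_test.cc hunks above all follow one mechanical pattern: the void helpers LiteralTestUtil::ExpectNear/ExpectEqual are replaced by predicate-style helpers wrapped in EXPECT_TRUE. A minimal sketch of the new call shape, using APIs that appear in the hunks themselves (the literal values are illustrative, not from the commit):

    // Compare an expected literal against a computed result within an error bound.
    // ErrorSpec, Literal::CreateR0 and LiteralTestUtil::Near all appear in the
    // hunks above; only the concrete values here are invented.
    xla::ErrorSpec error_spec(1e-4);
    std::unique_ptr<xla::Literal> expected = xla::Literal::CreateR0<float>(42.0f);
    std::unique_ptr<xla::Literal> actual = xla::Literal::CreateR0<float>(42.00001f);
    EXPECT_TRUE(xla::LiteralTestUtil::Near(*expected, *actual, error_spec));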
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
index a43ca3d5ca..5fd33b50c9 100644
--- a/tensorflow/compiler/xla/tests/call_test.cc
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include <memory>
#include <utility>
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 41f9a5f666..bf8ed4d9fb 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -178,8 +177,7 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral(
error, shape_with_layout));
}
-tensorflow::Status
-ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
+Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
const xla::XlaComputation& computation, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const std::function<void(const Literal& actual,
@@ -201,11 +199,10 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
"Test with output layout: ",
ShapeUtil::HumanStringWithLayout(layout)));
} while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
- return tensorflow::Status::OK();
+ return Status::OK();
}
-tensorflow::Status
-ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
+Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
const xla::XlaComputation& computation, const Literal& /*expected*/,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const std::function<void(const Literal& actual,
@@ -216,8 +213,8 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
// This is a recursive function. It's an std::function instead of a lambda
// because it needs to capture itself. The index is the index of the argument
// to try all layouts for.
- std::function<tensorflow::Status(int64)> choose;
- choose = [&, this](int64 index) -> tensorflow::Status {
+ std::function<Status(int64)> choose;
+ choose = [&, this](int64 index) -> Status {
if (index < arguments.size()) {
// Try out all layouts for the operand.
TF_ASSIGN_OR_RETURN(auto literal,
@@ -230,7 +227,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
TF_RETURN_IF_ERROR(choose(index + 1));
arguments_with_layout.pop_back();
layout_strings.pop_back();
- return tensorflow::Status::OK();
+ return Status::OK();
}
std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape()));
@@ -248,7 +245,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
layout_strings.pop_back();
} while (
std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
- return tensorflow::Status::OK();
+ return Status::OK();
}
// Every argument has an assigned layout.
@@ -263,13 +260,13 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
tensorflow::strings::StrAppend(&error_message, str, " ");
}
verify_output(*actual, error_message);
- return tensorflow::Status::OK();
+ return Status::OK();
};
return choose(0);
}
-tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
+Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
const Shape* shape_with_layout) {
@@ -297,7 +294,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
- converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
+ converted_expected = Literal::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
@@ -311,7 +308,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
}
auto expect_equal = [&](const Literal& actual, const string& error_message) {
- LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message;
};
if (execution_options_.debug_options().xla_test_all_output_layouts()) {
return ComputeAndCompareLiteralWithAllOutputLayouts(
@@ -323,11 +320,11 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- LiteralTestUtil::ExpectEqual(*expected_ptr, *actual);
- return tensorflow::Status::OK();
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
+ return Status::OK();
}
-tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
+Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
ErrorSpec error, const Shape* shape_with_layout) {
@@ -349,7 +346,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
- converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
+ converted_expected = Literal::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
@@ -363,7 +360,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
}
auto expect_near = [&](const Literal& actual, const string& error_message) {
- LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message);
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error))
+ << error_message;
};
if (execution_options_.debug_options().xla_test_all_output_layouts()) {
return ComputeAndCompareLiteralWithAllOutputLayouts(
@@ -375,8 +373,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error);
- return tensorflow::Status::OK();
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
+ return Status::OK();
}
void ClientLibraryTestBase::ComputeAndCompareR1U8(
@@ -407,7 +405,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- LiteralTestUtil::ExpectEqual(expected, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -419,7 +417,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- LiteralTestUtil::ExpectNear(expected, *actual, error);
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
}
void ClientLibraryTestBase::ComputeAndCompare(
@@ -431,7 +429,7 @@ void ClientLibraryTestBase::ComputeAndCompare(
}
std::unique_ptr<Literal> reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- LiteralTestUtil::ExpectEqual(*reference, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
}
void ClientLibraryTestBase::ComputeAndCompare(
@@ -444,7 +442,7 @@ void ClientLibraryTestBase::ComputeAndCompare(
}
std::unique_ptr<Literal> reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- LiteralTestUtil::ExpectNear(*reference, *result, error);
+ EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
}
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
@@ -562,7 +560,7 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
return builder->ConstantLiteral(
- use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
+ use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal);
}
std::unique_ptr<GlobalData>
@@ -583,7 +581,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(
const Literal* param_literal = &literal;
std::unique_ptr<Literal> converted_literal;
if (use_bfloat16_) {
- converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
+ converted_literal = Literal::ConvertF32ToBF16(literal);
param_literal = converted_literal.get();
}
std::unique_ptr<GlobalData> data =
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 16e838e60f..0499fec589 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -188,11 +188,11 @@ class ClientLibraryTestBase : public ::testing::Test {
const Shape* shape_with_layout = nullptr);
// ComputeAndCompare variant which returns an error status.
- tensorflow::Status ComputeAndCompareLiteralWithStatus(
+ Status ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_layout = nullptr);
- tensorflow::Status ComputeAndCompareLiteralWithStatus(
+ Status ComputeAndCompareLiteralWithStatus(
XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
const Shape* shape_with_layout = nullptr);
@@ -378,12 +378,12 @@ class ClientLibraryTestBase : public ::testing::Test {
ExecutionOptions execution_options_;
private:
- tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts(
+ Status ComputeAndCompareLiteralWithAllOutputLayouts(
const xla::XlaComputation& computation, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output);
- tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts(
+ Status ComputeAndCompareLiteralWithAllInputLayouts(
const xla::XlaComputation& computation, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const std::function<void(const Literal& actual,
@@ -541,7 +541,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR0(value);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ literal = Literal::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -555,7 +555,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR1(values);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ literal = Literal::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -569,7 +569,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ literal = Literal::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -583,7 +583,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ literal = Literal::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
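In both client_library_test_base files the bfloat16 conversion helper moves from LiteralTestUtil to a static method on Literal. A minimal sketch of the relocated call (the input literal is made up; the helper name is taken directly from the hunks):

    // F32 -> BF16 conversion is now Literal::ConvertF32ToBF16 rather than a
    // LiteralTestUtil helper.
    std::unique_ptr<xla::Literal> f32_literal =
        xla::Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
    std::unique_ptr<xla::Literal> bf16_literal =
        xla::Literal::ConvertF32ToBF16(*f32_literal);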
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index 0b425b93bb..08671cf624 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -62,9 +62,9 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
TF_ASSERT_OK_AND_ASSIGN(
auto computed, client_->Transfer(*data, &expected_literal->shape()));
- LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
- computed->shape());
- LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+ ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
+ expected_literal->shape(), computed->shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
}
@@ -91,9 +91,9 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
auto result,
client_->ExecuteAndTransfer(computation, {}, &execution_options));
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
- LiteralView::Create(*result, {0}));
+ LiteralSlice(*result, {0}));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
- LiteralView::Create(*result, {1}));
+ LiteralSlice(*result, {1}));
EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
@@ -142,7 +142,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
auto result_literal,
client_->Transfer(*results[0], &expected_result->shape()));
- LiteralTestUtil::ExpectEqual(*expected_result, *result_literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
}
} // namespace
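client_test.cc makes the same move for shape checks: AssertEqualShapesAndLayouts becomes a predicate wrapped in ASSERT_TRUE. A sketch with placeholder shapes (ShapeUtil::MakeShapeWithLayout is assumed as a convenient way to build the inputs; only EqualShapesAndLayouts is shown in the hunk):

    // Shapes built only for illustration; the helper under test is
    // LiteralTestUtil::EqualShapesAndLayouts.
    xla::Shape lhs = xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1});
    xla::Shape rhs = xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1});
    ASSERT_TRUE(xla::LiteralTestUtil::EqualShapesAndLayouts(lhs, rhs));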
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index ecce599a8a..50a0069648 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <memory>
#include <string>
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
@@ -50,8 +49,8 @@ class CompilationCacheTest : public ClientLibraryTestBase {
/*execution_options=*/&execution_options_,
&execution_profile)
.ConsumeValueOrDie();
- LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(expected_result),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR0<float>(expected_result), *result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -67,8 +66,8 @@ class CompilationCacheTest : public ClientLibraryTestBase {
.ConsumeValueOrDie();
std::unique_ptr<Literal> result =
client_->Transfer(*data_handle).ConsumeValueOrDie();
- LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>(expected_result),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR2<float>(expected_result), *result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index bf4b8fb0bc..ba22530f1c 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -208,7 +208,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal =
Literal::CreateR1<int32>({4, 6});
- LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
@@ -222,7 +222,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
- LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
@@ -244,9 +244,9 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
std::unique_ptr<Literal> expected_literal =
Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
LayoutUtil::MakeLayout(layout));
- LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
- computed->shape());
- LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+ ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
+ expected_literal->shape(), computed->shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
}
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
index 4743673561..fa963b175f 100644
--- a/tensorflow/compiler/xla/tests/constants_test.cc
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -169,9 +168,9 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) {
ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
LiteralTestUtil::ExpectR2Near<float>(
- {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_);
+ {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_);
LiteralTestUtil::ExpectR1Near<float>(
- {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_);
+ {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_);
}
} // namespace
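constants_test.cc and client_test.cc also switch tuple-element views from LiteralView::Create(result, index) to constructing a LiteralSlice directly. A sketch of the new form, with a made-up tuple literal (Literal::MakeTuple is an assumption used only to set up the example; the LiteralSlice constructor usage matches the hunks):

    // View one element of a tuple literal without copying it.
    std::unique_ptr<xla::Literal> a = xla::Literal::CreateR0<float>(1.0f);
    std::unique_ptr<xla::Literal> b = xla::Literal::CreateR1<float>({2.0f, 3.0f});
    std::unique_ptr<xla::Literal> tuple =
        xla::Literal::MakeTuple({a.get(), b.get()});
    xla::LiteralSlice first(*tuple, /*shape_index=*/{0});   // views `a`
    xla::LiteralSlice second(*tuple, /*shape_index=*/{1});  // views `b`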
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index 50d6e25d86..fea850dc13 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 155fbacf58..2b3390ca98 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -49,7 +49,7 @@ class CopyOpTest : public HloTestBase {
module->AddEntryComputation(std::move(computation));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectEqual(literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
}
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
@@ -253,7 +253,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
.ConsumeValueOrDie();
- LiteralTestUtil::ExpectEqual(*empty, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc
index c76e5aabf4..bfe688e20d 100644
--- a/tensorflow/compiler/xla/tests/deallocation_test.cc
+++ b/tensorflow/compiler/xla/tests/deallocation_test.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
index d0ada24748..12789fe665 100644
--- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index efa5aed2d1..0fd846cef8 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -61,7 +61,7 @@ using TypesF16F32F64CF64 = ::testing::Types<Eigen::half, float>;
#endif
// Check that we can safely pass an input tuple's elements to a dot operation.
-TEST_F(DotOperationTest, DotOfInputTupleElem) {
+XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
XlaBuilder builder(TestName());
XlaOp param;
@@ -798,7 +798,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
this->error_spec_);
}
-TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@@ -826,7 +826,7 @@ TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@@ -855,7 +855,7 @@ TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-TEST_F(DotOperationTest,
+XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
DotOfGatherOptimizationWithConstRHSReverseMM)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
@@ -886,7 +886,7 @@ TEST_F(DotOperationTest,
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-TEST_F(DotOperationTest,
+XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
DotOfGatherOptimizationWithConstLHSReverseMM)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
@@ -917,7 +917,7 @@ TEST_F(DotOperationTest,
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-TEST_F(DotOperationTest,
+XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(
DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
@@ -953,7 +953,7 @@ TEST_F(DotOperationTest,
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-TEST_F(DotOperationTest,
+XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(
DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
@@ -989,7 +989,7 @@ TEST_F(DotOperationTest,
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-TEST_F(DotOperationTest,
+XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(
DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
@@ -1017,7 +1017,7 @@ TEST_F(DotOperationTest,
}
// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-TEST_F(DotOperationTest,
+XLA_TEST_F(DotOperationTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(
DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index b947f8208a..e6f79b5ac5 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -118,9 +118,9 @@ class FusionTest : public HloTestBase {
auto expected = Literal::CreateR2FromArray2D(answer_data);
auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
if (primitive_util::IsFloatingPointType(prim_type)) {
- LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
} else {
- LiteralTestUtil::ExpectEqual(*expected, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
}
}
@@ -221,9 +221,9 @@ XLA_TEST_F(FusionTest, Test) {
const4, reshape3, add2, const1, const0},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{0.5}, {2.72}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}),
- ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR2<float>({{0.5}, {2.72}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
// Test whether we emit appropriate code for parameters of fusion instructions.
@@ -247,9 +247,9 @@ XLA_TEST_F(FusionTest, Parameter) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}),
- ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
@@ -307,9 +307,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectNear(
+ EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4));
+ *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, ReshapeToScalar) {
@@ -322,8 +322,9 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(5),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR0<int32>(5),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
@@ -336,9 +337,9 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
@@ -351,9 +352,9 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
@@ -366,8 +367,9 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__1by1by1) {
@@ -380,8 +382,9 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR3<int32>({{{7}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{7}}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__) {
@@ -394,8 +397,9 @@ XLA_TEST_F(FusionTest, Reshape__) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
@@ -408,9 +412,9 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_2by3) {
@@ -423,9 +427,9 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_3by3) {
@@ -438,9 +442,9 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reverse) {
@@ -454,8 +458,9 @@ XLA_TEST_F(FusionTest, Reverse) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({3, 2, 1}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({3, 2, 1}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReverseNegate) {
@@ -471,8 +476,9 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-3, -2, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-3, -2, -1}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, BroadcastNegate) {
@@ -488,8 +494,9 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -1}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, SliceNegate) {
@@ -505,8 +512,9 @@ XLA_TEST_F(FusionTest, SliceNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -3}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
@@ -526,8 +534,9 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
/*instructions_to_fuse=*/{negate3, dynamic_slice2},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-2, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-2, -3}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReshapeNegate) {
@@ -543,8 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
// TODO(b/64070202): Investigate failure.
@@ -561,8 +571,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
std::unique_ptr<HloComputation> MakeReduceTestComputation() {
@@ -591,8 +602,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(15),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR0<int32>(15),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
@@ -612,8 +624,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(-15),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR0<int32>(-15),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
@@ -661,9 +674,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
// When a constant (or other op) which has multiple users is imported
@@ -697,8 +710,9 @@ XLA_TEST_F(FusionTest, SharedConstant) {
// fused instruction contains the constant(2), the parameter, and 4 adds
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({8}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
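Because the rewritten comparison helpers return a result instead of asserting internally, failure context that used to be passed as a string argument is now streamed onto the gtest macro, as in the client_library_test_base.cc hunks earlier. A short sketch of that pattern (the literals are placeholders; the Equal call and streaming form are taken from the hunks):

    // LiteralTestUtil::Equal returns a value usable inside EXPECT_TRUE, so extra
    // context composes with the usual gtest streaming syntax.
    std::unique_ptr<xla::Literal> expected = xla::Literal::CreateR1<float>({8.0f});
    std::unique_ptr<xla::Literal> actual = xla::Literal::CreateR1<float>({8.0f});
    EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual))
        << "mismatch while checking fused computation output";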
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 130456e61c..4854c649c1 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -629,8 +629,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
client_->ExecuteParallel(computation_instances));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
client_->Transfer(*(result_data[0])));
- LiteralTestUtil::ExpectEqual(
- *result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}})));
}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index c28f79ae38..cde1dcd9cd 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -15,978 +15,93 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
-#include <unistd.h>
-#include <cmath>
-#include <vector>
-
-#include "tensorflow/compiler/xla/index_util.h"
-#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/compiler/xla/literal_comparison.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/types.h"
namespace xla {
-using ::tensorflow::strings::Appendf;
-using ::tensorflow::strings::Printf;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
-/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes(
- const Shape& expected, const Shape& actual) {
- if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
- return ::testing::AssertionFailure()
- << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected)
- << " got: " << ShapeUtil::HumanString(actual);
- }
- if (ShapeUtil::IsTuple(expected)) {
- if (ShapeUtil::TupleElementCount(expected) !=
- ShapeUtil::TupleElementCount(actual)) {
- return ::testing::AssertionFailure()
- << "want tuple element count: "
- << ShapeUtil::TupleElementCount(expected)
- << " got tuple element count: "
- << ShapeUtil::TupleElementCount(actual);
- }
- for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
- ::testing::AssertionResult result =
- EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i))
- << "mismatch in tuple index " << i;
- if (!result) {
- return result;
- }
- }
- } else {
- if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
- return ::testing::AssertionFailure()
- << "want rank of: " << ShapeUtil::HumanString(expected)
- << " got rank of: " << ShapeUtil::HumanString(actual);
- }
- if (expected.element_type() != actual.element_type()) {
- return ::testing::AssertionFailure()
- << PrimitiveType_Name(expected.element_type()) << " vs "
- << PrimitiveType_Name(actual.element_type());
- }
- if (expected.dimensions_size() != actual.dimensions_size()) {
- return ::testing::AssertionFailure()
- << "want dimensions_size " << expected.dimensions_size()
- << " got dimensions_size " << actual.dimensions_size();
- }
- for (int i = 0; i < expected.dimensions_size(); ++i) {
- if (expected.dimensions(i) != actual.dimensions(i)) {
- return ::testing::AssertionFailure()
- << "mismatch in dimension #" << i
- << " expected: " << ShapeUtil::HumanString(expected)
- << " actual: " << ShapeUtil::HumanString(actual);
- }
- }
- }
- return ::testing::AssertionSuccess();
-}
-
-/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected,
- const Shape& actual) {
- ASSERT_TRUE(EqualShapes(expected, actual));
-}
-
-/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts(
- const Shape& expected, const Shape& actual) {
- ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString());
-}
-
-namespace {
-
-// Return a literal with all arrays of type FromNativeT converted to type
-// ToNativeT in the given literal.
-template <typename FromNativeT, typename ToNativeT>
-std::unique_ptr<Literal> ConvertType(const Literal& literal) {
- // First construct shape of the result.
- Shape result_shape(literal.shape());
- ShapeUtil::ForEachMutableSubshape(
- &result_shape, [](Shape* subshape, const ShapeIndex&) {
- if (subshape->element_type() ==
- primitive_util::NativeToPrimitiveType<FromNativeT>()) {
- subshape->set_element_type(
- primitive_util::NativeToPrimitiveType<ToNativeT>());
- }
- });
- auto result = MakeUnique<Literal>(result_shape);
-
- // Then copy over the data from 'literal' converting FromNativeT values to
- // ToNativeT values as necessary.
- ShapeUtil::ForEachSubshape(
- literal.shape(),
- [&](const Shape& subshape, const ShapeIndex& shape_index) {
- if (ShapeUtil::IsArray(subshape)) {
- if (subshape.element_type() ==
- primitive_util::NativeToPrimitiveType<FromNativeT>()) {
- auto src = literal.data<FromNativeT>(shape_index);
- auto dest = result->data<ToNativeT>(shape_index);
- for (int64 i = 0; i < src.size(); ++i) {
- dest[i] = static_cast<ToNativeT>(src[i]);
- }
- } else {
- TF_CHECK_OK(result->CopyFrom(literal,
- /*dest_shape_index=*/shape_index,
- /*src_shape_index=*/shape_index));
- }
- }
- });
- return result;
-}
-
-} // namespace
-
-/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
- const Literal& literal) {
- return ConvertType<bfloat16, float>(literal);
-}
-
-/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
- const Literal& literal) {
- return ConvertType<float, bfloat16>(literal);
-}
-
namespace {
-string Hostname() {
- char hostname[1024];
- gethostname(hostname, sizeof hostname);
- hostname[sizeof hostname - 1] = 0;
- return string(hostname);
-}
-
-// Helper function for comparing a floating point type, FloatT, bitwise equal
-// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
-// -- on miscompare, a nice error message is given in the AssertionFailure.
-template <typename FloatT, typename UnsignedT>
-::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
- auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
- auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
- auto lhs_double = static_cast<double>(lhs);
- auto rhs_double = static_cast<double>(rhs);
- if (ulhs != urhs) {
- return ::testing::AssertionFailure() << Printf(
- "floating values are not bitwise-equal; and equality testing "
- "was requested: %s=%g=%a vs %s=%g=%a",
- StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double,
- lhs_double, StrCat(tensorflow::strings::Hex(urhs)).c_str(),
- rhs_double, rhs_double);
- }
- return ::testing::AssertionSuccess();
-}
-
-// Templated comparator that specializes for float equality comparison with the
-// bitwise helper above (this is the un-specialized fallback, to just use the
-// default gunit implementation).
-template <typename NativeT>
-::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) {
- if (lhs == rhs) {
+// Writes the given literal to a file in the test temporary directory.
+void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) {
+ auto get_hostname = [] {
+ char hostname[1024];
+ gethostname(hostname, sizeof hostname);
+ hostname[sizeof hostname - 1] = 0;
+ return string(hostname);
+ };
+ int64 now_usec = tensorflow::Env::Default()->NowMicros();
+ string filename = tensorflow::io::JoinPath(
+ tensorflow::testing::TmpDir(),
+ tensorflow::strings::Printf("tempfile-%s-%llx-%s", get_hostname().c_str(),
+ now_usec, name.c_str()));
+ TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename,
+ literal.ToProto()));
+ LOG(ERROR) << "wrote to " << name << " file: " << filename;
+}
+
+// Callback helper that dumps literals to temporary files in the event of a
+// miscomparison.
+void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual,
+ const LiteralSlice& mismatches) {
+ LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " "
+ << literal_comparison::ToStringTruncated(expected);
+ LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " "
+ << literal_comparison::ToStringTruncated(actual);
+ LOG(INFO) << "Dumping literals to temp files...";
+ WriteLiteralToTempFile(expected, "expected");
+ WriteLiteralToTempFile(actual, "actual");
+ WriteLiteralToTempFile(mismatches, "mismatches");
+}
+
+::testing::AssertionResult StatusToAssertion(const Status& s) {
+ if (s.ok()) {
return ::testing::AssertionSuccess();
}
- ::testing::Message msg;
- msg << "Expected equality of these values:";
- msg << "\n " << lhs;
- msg << "\n " << rhs;
-
- return ::testing::AssertionFailure() << msg;
-}
-
-// Specializations for floating types that do bitwise comparisons when equality
-// comparison is requested.
-template <>
-::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
- return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<Eigen::half>(Eigen::half lhs,
- Eigen::half rhs) {
- return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
- return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
- return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<complex64>(complex64 lhs,
- complex64 rhs) {
- auto res = CompareEqual<float>(lhs.real(), rhs.real());
- if (!res) {
- return res;
- }
- return CompareEqual<float>(lhs.imag(), rhs.imag());
-}
-
-// A recursive function which iterates through every index of expected and
-// actual literal and compares their values elementwise. Returns true if all
-// elements are equal.
-template <typename NativeT>
-bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
- tensorflow::gtl::MutableArraySlice<int64> multi_index,
- int64 dimension) {
- if (dimension == expected.shape().dimensions_size()) {
- NativeT expected_value = expected.Get<NativeT>(multi_index);
- NativeT actual_value = actual.Get<NativeT>(multi_index);
- ::testing::AssertionResult result =
- CompareEqual<NativeT>(expected_value, actual_value);
- return result; // Defines implicit coercion to bool.
- }
-
- bool all_match = true;
- for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
- multi_index[dimension] = i;
- all_match = all_match && ExpectLiteralsEqual<NativeT>(
- expected, actual, multi_index, dimension + 1);
- }
- return all_match;
+ return ::testing::AssertionFailure() << s.error_message();
}
} // namespace
-/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected,
- const Literal& actual,
- const string& message) {
- EXPECT_TRUE(Equal(expected, actual))
- << "expected:\n"
- << expected.ToString() << "\n\tvs actual:\n"
- << actual.ToString()
- << (message.empty() ? "" : StrCat("\nmessage: ", message));
-}
-
-/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected,
- const Literal& actual) {
- EXPECT_FALSE(Equal(expected, actual));
-}
-
-/* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
- const Literal& expected, const Literal& actual) {
- VLOG(1) << "expected:";
- XLA_VLOG_LINES(1, expected.ToString());
- VLOG(1) << "actual:";
- XLA_VLOG_LINES(1, actual.ToString());
-
- AssertEqualShapes(expected.shape(), actual.shape());
- std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
- bool match = false;
- switch (expected.shape().element_type()) {
- case PRED:
- match = ExpectLiteralsEqual<bool>(expected, actual, &multi_index, 0);
- break;
- case U8:
- match = ExpectLiteralsEqual<uint8>(expected, actual, &multi_index, 0);
- break;
- case S32:
- match = ExpectLiteralsEqual<int32>(expected, actual, &multi_index, 0);
- break;
- case S64:
- match = ExpectLiteralsEqual<int64>(expected, actual, &multi_index, 0);
- break;
- case U32:
- match = ExpectLiteralsEqual<uint32>(expected, actual, &multi_index, 0);
- break;
- case U64:
- match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
- break;
- case BF16:
- match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
- break;
- case F16:
- match = ExpectLiteralsEqual<half>(expected, actual, &multi_index, 0);
- break;
- case F32:
- match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
- break;
- case F64:
- match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0);
- break;
- case C64:
- match = ExpectLiteralsEqual<complex64>(expected, actual, &multi_index, 0);
- break;
- case TUPLE: {
- bool tuple_match = true;
- for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
- SCOPED_TRACE(StrCat("Tuple index ", i, " in ",
- ShapeUtil::HumanString(expected.shape())));
-
- // Create LiteralViews of the expected and actual elements.
- auto result = Equal(LiteralView::Create(expected, {i}),
- LiteralView::Create(actual, {i}));
- tuple_match = tuple_match ? !!result : false;
- }
- match = tuple_match;
- break;
- }
- default:
- LOG(FATAL)
- << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
- << PrimitiveType_Name(expected.shape().element_type());
- }
- ::testing::AssertionResult result = ::testing::AssertionSuccess();
- if (!match) {
- result = ::testing::AssertionFailure()
- << "expected: " << expected.ToString()
- << "\nactual: " << actual.ToString();
- VLOG(1) << result.message();
- }
- return result;
-}
-
-namespace {
-
-// Gets the total element count. For tuples, this is not the count of tuple
-// elements, but the sum of elements of each tuple element.
-int64 RecursiveElementCount(const Shape& shape) {
- if (ShapeUtil::IsTuple(shape)) {
- const int64 tuple_elements = ShapeUtil::TupleElementCount(shape);
- int64 total = 0;
- for (int64 i = 0; i < tuple_elements; ++i) {
- total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
- }
- return total;
- } else {
- return ShapeUtil::ElementsIn(shape);
- }
-}
-
-// Calling ToString on a literal with over 100 million elements takes around
-// 3 minutes. The utility of printing a literal with >1000 elements is
-// questionable, especially when writing the Literal proto to disk is orders
-// of magnitude faster.
-string TruncateHugeLiteral(const Literal& literal) {
- return RecursiveElementCount(literal.shape()) < 1000
- ? literal.ToString()
- : "[TRUNCATED, Literal with more than 1000 values]";
+/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes(
+ const Shape& expected, const Shape& actual) {
+ return StatusToAssertion(literal_comparison::EqualShapes(expected, actual));
}
-// Returns whether the actual and expected values are mismatched with respect to
-// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec.
-template <typename NativeT>
-bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) {
- if (relaxed_nans) {
- return !std::isnan(expected) && std::isnan(actual);
- } else {
- return std::isnan(expected) != std::isnan(actual);
+/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapesAndLayouts(
+ const Shape& expected, const Shape& actual) {
+ if (expected.ShortDebugString() != actual.ShortDebugString()) {
+ return ::testing::AssertionFailure()
+ << "want: " << expected.ShortDebugString()
+ << " got: " << actual.ShortDebugString();
}
+ return ::testing::AssertionSuccess();
}
-template <>
-bool NanMismatch<complex64>(complex64 expected, complex64 actual,
- bool relaxed_nans) {
- return NanMismatch<float>(expected.real(), actual.real(), relaxed_nans) ||
- NanMismatch<float>(expected.imag(), actual.imag(), relaxed_nans);
-}
-
-template <>
-bool NanMismatch<half>(half expected, half actual, bool relaxed_nans) {
- return NanMismatch<float>(static_cast<float>(expected),
- static_cast<float>(actual), relaxed_nans);
-}
-
-// Converts the given floating-point value to a string.
-template <typename NativeT>
-string FpValueToString(NativeT value) {
- return Printf("%8.4g", static_cast<double>(value));
-}
-
-template <>
-string FpValueToString<complex64>(complex64 value) {
- return Printf("%8.4g + %8.4fi", value.real(), value.imag());
-}
-
-// Returns the absolute value of the given floating point value. This function
-// is used instead of std::abs directly in order to allow type-dependent
-// implementations for NearComparator.
-template <typename NativeT>
-float FpAbsoluteValue(NativeT value) {
- return std::abs(value);
-}
-
-template <>
-float FpAbsoluteValue(bfloat16 value) {
- return FpAbsoluteValue<float>(static_cast<float>(value));
-}
-
-template <>
-float FpAbsoluteValue(half value) {
- return FpAbsoluteValue<float>(static_cast<float>(value));
-}
-
-// Helper class for comparing floating-point literals within an error bound.
-template <typename NativeT>
-class NearComparator {
- public:
- // Compares the two array literals elementwise and returns an assertion
- // result. The assertion result is successful if all actual and expected
- // elements are within the given error bound. In case of error, the assertion
- // result contains a detailed error message in case of failure.
- static ::testing::AssertionResult Compare(const Literal& expected,
- const Literal& actual,
- ErrorSpec error,
- bool detailed_message) {
- NearComparator<NativeT> comparator(expected, actual, error,
- detailed_message);
- return comparator.Run();
- }
-
- private:
- // Data structure encapsulating metadata about a single element mismatch.
- struct Mismatch {
- NativeT actual;
- NativeT expected;
- float rel_error;
- float abs_error;
-
- // The linear index of the failure within the shape. This linear index is
- // from the 'actual' literal.
- int64 linear_index;
-
- bool operator<(const Mismatch& other) const {
- return rel_error < other.rel_error;
- }
-
- string ToString(const Shape& shape) const {
- return Printf(
- "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
- FpValueToString(actual).c_str(), FpValueToString(expected).c_str(),
- LiteralTestUtil::MultiIndexAsString(
- IndexUtil::LinearIndexToMultidimensionalIndex(shape,
- linear_index))
- .c_str(),
- rel_error, abs_error);
- }
- };
-
- explicit NearComparator(const Literal& expected, const Literal& actual,
- ErrorSpec error, bool detailed_message)
- : expected_(expected),
- actual_(actual),
- error_(error),
- detailed_message_(detailed_message),
- abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}),
- abs_error_buckets_(kErrorBucketBounds.size(), 0),
- rel_error_buckets_(kErrorBucketBounds.size(), 0) {}
-
- // Runs the comparison between expected and actual literals.
- ::testing::AssertionResult Run() {
- VLOG(1) << "expected:";
- XLA_VLOG_LINES(1, TruncateHugeLiteral(expected_));
- VLOG(1) << "actual:";
- XLA_VLOG_LINES(1, TruncateHugeLiteral(actual_));
-
- // If the shapes mismatch, we simply fail the expectation instead of
- // printing out data, as it's a type error rather than a value error.
- ::testing::AssertionResult equal_shapes =
- LiteralTestUtil::EqualShapes(expected_.shape(), actual_.shape());
- if (!equal_shapes) {
- return equal_shapes;
- }
- if (!ShapeUtil::IsArray(expected_.shape())) {
- return ::testing::AssertionFailure() << "Expected array shape";
- }
-
- mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED));
- mismatches_.PopulateWithValue(false);
-
- CompareLiterals();
-
- if (num_mismatches_ == 0) {
- return ::testing::AssertionSuccess();
- } else if (!VLOG_IS_ON(1)) {
- LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected_.shape())
- << " " << TruncateHugeLiteral(expected_);
- LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual_.shape())
- << " " << TruncateHugeLiteral(actual_);
- LOG(INFO) << "Dumping literals to temp files...";
- WriteLiteralToTempFile(expected_, "expected");
- WriteLiteralToTempFile(actual_, "actual");
- WriteLiteralToTempFile(mismatches_, "mismatches");
- }
- return ::testing::AssertionFailure() << ErrorMessage();
- }
-
- // Insert the given absolute value into the absolute value bucket vector. The
- // bounds of the buckets are given by kAbsValueBucketBounds.
- void UpdateAbsValueBucket(NativeT value, bool is_mismatch) {
- // Adjust the bucket containing the absolute values of the 'actual'
- // elements.
- const float abs_value = FpAbsoluteValue(value);
- for (int i = 0; i < abs_value_buckets_.size(); ++i) {
- if (i == abs_value_buckets_.size() - 1 ||
- (abs_value >= kAbsValueBucketBounds[i] &&
- abs_value < kAbsValueBucketBounds[i + 1])) {
- // The first value of the pair is the count of elements in the bucket,
- // the second is the count of mismatches in the bucket.
- abs_value_buckets_[i].first++;
- if (is_mismatch) {
- abs_value_buckets_[i].second++;
- }
- return;
- }
- }
- }
-
- // Insert the given error into the given error bucket vector.
- void UpdateErrorBucket(
- float error, tensorflow::gtl::MutableArraySlice<int64> error_buckets) {
- CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
- for (int i = 0; i < error_buckets.size(); ++i) {
- if (error >= kErrorBucketBounds[i]) {
- error_buckets[i]++;
- }
- }
- }
-
- // Compares the two given elements from the expected and actual literals at
- // the given literal_index and keeps track of various mismatch statistics.
- void CompareValues(NativeT expected, NativeT actual, int64 linear_index) {
- const bool is_nan_mismatch =
- NanMismatch(expected, actual, error_.relaxed_nans);
- float abs_error;
- float rel_error;
- if (actual == expected) {
- abs_error = 0;
- rel_error = 0;
- } else if (is_nan_mismatch) {
- num_nan_mismatches_++;
- // A nan mismatch is considered to have infinite error. rel_error is used
- // for sorting a std::set of the top mismatches, and a nan value here will
- // result in undefined behavior because nan's do not satisfy the strict
- // weak ordering requirement of std containers.
- abs_error = std::numeric_limits<float>::infinity();
- rel_error = std::numeric_limits<float>::infinity();
- } else {
- abs_error = FpAbsoluteValue(actual - expected);
- rel_error = abs_error / FpAbsoluteValue(expected);
- }
- const bool is_abs_mismatch = abs_error > error_.abs;
- const bool is_rel_mismatch = rel_error > error_.rel;
- const bool is_mismatch =
- is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch);
-
- // Update the error of the relative bucket only if the *absolute* error
- // bound is exceeded and vice versa.
- if (is_abs_mismatch) {
- num_abs_mismatches_++;
- UpdateErrorBucket(rel_error, &rel_error_buckets_);
- }
- if (is_rel_mismatch) {
- num_rel_mismatches_++;
- UpdateErrorBucket(abs_error, &abs_error_buckets_);
- }
-
- UpdateAbsValueBucket(actual, is_mismatch);
-
- if (!is_mismatch) {
- return;
- }
-
- num_mismatches_++;
-
- // Keep track of the kTopRelativeErrorCount relative error mismatches.
- if (top_rel_mismatches_.size() < kTopRelativeErrorCount ||
- rel_error > top_rel_mismatches_.begin()->rel_error) {
- Mismatch mismatch = {actual, expected, rel_error, abs_error,
- linear_index};
- top_rel_mismatches_.insert(mismatch);
- if (top_rel_mismatches_.size() > kTopRelativeErrorCount) {
- top_rel_mismatches_.erase(top_rel_mismatches_.begin());
- }
- }
-
- mismatches_.data<bool>()[linear_index] = true;
- }
-
- // Compares the two literals elementwise.
- void CompareLiterals() {
- // Fast path optimization for the case where layouts match.
- if (LayoutUtil::Equal(actual_.shape().layout(),
- expected_.shape().layout())) {
- tensorflow::gtl::ArraySlice<const NativeT> expected_data =
- expected_.data<NativeT>();
- tensorflow::gtl::ArraySlice<const NativeT> actual_data =
- actual_.data<NativeT>();
- const int64 len = expected_data.size();
- for (int64 i = 0; i < len; ++i) {
- CompareValues(expected_data[i], actual_data[i], i);
- }
- return;
- }
- std::vector<int64> multi_index(ShapeUtil::Rank(actual_.shape()), 0);
- CompareLiteralsSlow(0, &multi_index);
- }
-
- // Slow path for CompareLiterals when 'actual' and 'expected' literals have
- // different layouts. In this case, multidimensional indices are constructed
- // and indexed for each element.
- void CompareLiteralsSlow(int64 dimension, std::vector<int64>* multi_index) {
- if (dimension == multi_index->size()) {
- CompareValues(expected_.Get<NativeT>(*multi_index),
- actual_.Get<NativeT>(*multi_index),
- IndexUtil::MultidimensionalIndexToLinearIndex(
- actual_.shape(), *multi_index));
- } else {
- for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) {
- (*multi_index)[dimension] = i;
- CompareLiteralsSlow(dimension + 1, multi_index);
- }
- }
- }
-
- // Writes the given literal to a file in the test temporary directory.
- void WriteLiteralToTempFile(const Literal& literal, const string& name) {
- int64 now_usec = tensorflow::Env::Default()->NowMicros();
- string filename = tensorflow::io::JoinPath(
- tensorflow::testing::TmpDir(),
- Printf("tempfile-%s-%llx-%s", Hostname().c_str(), now_usec,
- name.c_str()));
- TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
- filename, literal.ToProto()));
- LOG(ERROR) << "wrote to " << name << " file: " << filename;
- }
-
- // Returns an error message string with a detailed breakdown of the
- // mismatches. Called after calling Run().
- string ErrorMessage() {
- string out;
- int64 element_count = ShapeUtil::ElementsIn(actual_.shape());
-
- auto percent_string = [](float a, float b) {
- float pct = b == 0.0 ? 0.0 : 100.0 * a / b;
- return Printf("%0.4f%%", pct);
- };
-
- Appendf(&out,
- "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound "
- "%g, rel bound %g\n",
- num_mismatches_,
- percent_string(num_mismatches_, element_count).c_str(),
- ShapeUtil::HumanString(actual_.shape()).c_str(),
- ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
- if (num_nan_mismatches_ > 0) {
- StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n");
- }
- Appendf(&out, "Top relative error mismatches:\n");
- for (auto it = top_rel_mismatches_.rbegin();
- it != top_rel_mismatches_.rend(); ++it) {
- StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n");
- }
-
- if (!detailed_message_) {
- return out;
- }
-
- StrAppend(&out, "Absolute magnitude breakdown of actual values:\n");
- CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size());
- for (int i = 0; i < abs_value_buckets_.size(); ++i) {
- const int64 bucket_size = abs_value_buckets_[i].first;
- const int64 bucket_mismatches = abs_value_buckets_[i].second;
- string mismatch_str = bucket_mismatches > 0
- ? Printf(", mismatches %lld", bucket_mismatches)
- : "";
- Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n",
- kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
- bucket_size, percent_string(bucket_size, element_count).c_str(),
- mismatch_str.c_str());
- }
-
- auto print_accum_buckets = [&](const string& header, int64 total,
- tensorflow::gtl::ArraySlice<int64> buckets) {
- StrAppend(&out, header, ":\n");
- Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0],
- total - buckets[0],
- percent_string(total - buckets[0], total).c_str());
- CHECK_EQ(buckets.size(), kErrorBucketBounds.size());
- for (int i = 0; i < kErrorBucketBounds.size(); ++i) {
- Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i],
- buckets[i], percent_string(buckets[i], total).c_str());
- }
- };
- Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n",
- error_.abs, num_abs_mismatches_,
- percent_string(num_abs_mismatches_, element_count).c_str());
- print_accum_buckets(
- "Relative error breakdown of elements exceeding abs error bound",
- num_abs_mismatches_, rel_error_buckets_);
- Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n",
- error_.rel, num_rel_mismatches_,
- percent_string(num_rel_mismatches_, element_count).c_str());
- print_accum_buckets(
- "Absolute error breakdown of elements exceeding rel error bound",
- num_rel_mismatches_, abs_error_buckets_);
- return out;
- }
-
- // 'actual' and 'expected' literals being compared.
- const Literal& expected_;
- const Literal& actual_;
-
- // The error bounds of the comparison.
- ErrorSpec error_;
-
- // Whether to include detailed breakdown of mismatches in the error message.
- bool detailed_message_;
-
- // Number of element mismatches encountered so far.
- int64 num_mismatches_ = 0;
-
- // Number of elements with a nan mismatch.
- int64 num_nan_mismatches_ = 0;
-
- // Number of elements which exceed the absolute/relative error bound.
- int64 num_abs_mismatches_ = 0;
- int64 num_rel_mismatches_ = 0;
-
- // A Literal containing which elements did not match in the expected and
- // actual literals. mismatches_ contains PREDs and is of the same sizes as
- // the comparison literals.
- Literal mismatches_;
-
- // The number of mismatches to report in the output, sorted by relative error
- // magnitude.
- static constexpr int64 kTopRelativeErrorCount = 5;
-
- // The set of mismatches with the largest relative error. The size of this set
- // is bounded by kTopRelativeErrorCount.
- std::multiset<Mismatch> top_rel_mismatches_;
-
- // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the
- // bounds of these buckets. abs_value_buckets_ contains a pair for each
- // bucket: the element count and failure count.
- static constexpr std::array<float, 7> kAbsValueBucketBounds = {
- 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits<float>::infinity()};
- std::vector<std::pair<int64, int64>> abs_value_buckets_;
-
- // Buckets for relative and absolute errors. The relative error buckets only
- // contains those elements which exceed the *absolute* error bound, and vice
- // versa. This makes it easy to see the effect of adjusting the relative (or
- // absolute) error bound on the success of the comparison. kErrorBucketBounds
- // are the lower bounds of the buckets in both vectors. The error buckets are
- // a cumulative distribution so an error value may appear in more than one
- // bucket. For example an error value of 0.003 may appear in the buckets
- // bounded by 0.01, 0.1, and 1.0.
- static constexpr std::array<float, 5> kErrorBucketBounds = {0.0001, 0.001,
- 0.01, 0.1, 1};
- std::vector<int64> abs_error_buckets_;
- std::vector<int64> rel_error_buckets_;
-};
-
-template <typename NativeT>
-constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
-template <typename NativeT>
-constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
-
-// Helper function for comparing two literals for nearness. Handles tuple-shapes
-// via recursion. shape_index is the ShapeIndex of expected (or actual)
-// currently being compared.
-::testing::AssertionResult NearHelper(const Literal& expected,
- const Literal& actual,
- const ErrorSpec& error,
- bool detailed_message,
- const ShapeIndex& shape_index) {
- ::testing::AssertionResult err =
- LiteralTestUtil::EqualShapes(expected.shape(), actual.shape());
- if (!err) {
- return err;
- }
-
- if (ShapeUtil::IsTuple(expected.shape())) {
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
- const auto expected_element = LiteralView::Create(expected, {i});
- const auto actual_element = LiteralView::Create(actual, {i});
- ShapeIndex element_index = shape_index;
- element_index.push_back(i);
- ::testing::AssertionResult res =
- NearHelper(expected_element, actual_element, error, detailed_message,
- element_index);
- if (!res) {
- string err_message =
- Printf("\nArray at shape index %s%s",
- element_index.ToString().c_str(), res.message());
- if (err) {
- err = ::testing::AssertionFailure() << err_message;
- } else {
- err << err_message;
- }
- }
- }
- if (!err && shape_index.empty()) {
- // Emit a top-level error message containing the top-level shape in case
- // of mismatch.
- int64 total_elements = RecursiveElementCount(actual.shape());
- err = ::testing::AssertionFailure()
- << Printf("\nMismatches in shape %s (%lld elements):\n%s",
- ShapeUtil::HumanString(actual.shape()).c_str(),
- total_elements, err.message());
- }
- return err;
- }
-
- if (ShapeUtil::ElementIsFloating(expected.shape()) ||
- ShapeUtil::ElementIsComplex(expected.shape())) {
- switch (expected.shape().element_type()) {
- case BF16:
- return NearComparator<bfloat16>::Compare(expected, actual, error,
- detailed_message);
- break;
- case F16:
- return NearComparator<half>::Compare(expected, actual, error,
- detailed_message);
- break;
- case F32:
- return NearComparator<float>::Compare(expected, actual, error,
- detailed_message);
- break;
- case F64:
- return NearComparator<double>::Compare(expected, actual, error,
- detailed_message);
- break;
- case C64:
- return NearComparator<complex64>::Compare(expected, actual, error,
- detailed_message);
- break;
- default:
- LOG(FATAL) << "Unsupported primitive type in near comparator: "
- << PrimitiveType_Name(expected.shape().element_type())
- << ". Must be floating-point type.";
- }
- }
-
- // Non-floating point literal.
- return LiteralTestUtil::Equal(expected, actual);
+/* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
+ const LiteralSlice& expected, const LiteralSlice& actual) {
+ return StatusToAssertion(literal_comparison::Equal(expected, actual));
}
-} // namespace
-
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
- const Literal& expected, const Literal& actual, const ErrorSpec& error,
- bool detailed_message) {
- return NearHelper(expected, actual, error, detailed_message,
- /*shape_index=*/{});
-}
-
-/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected,
- const Literal& actual,
- const ErrorSpec& error,
- const string& message) {
- ::testing::AssertionResult res =
- Near(expected, actual, error, /*detailed_message=*/false);
- if (!res) {
- res << "Expected: " << TruncateHugeLiteral(expected) << "\n";
- res << "Actual: " << TruncateHugeLiteral(actual) << "\n";
- if (!message.empty()) {
- res << StrCat("\nmessage: ", message);
- }
- }
- EXPECT_TRUE(res);
+ const LiteralSlice& expected, const LiteralSlice& actual,
+ const ErrorSpec& error_spec, bool detailed_message) {
+ return StatusToAssertion(literal_comparison::Near(
+ expected, actual, error_spec, detailed_message, &OnMiscompare));
}
-/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
- const Literal& expected, const Literal& actual,
+/* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
+ const LiteralSlice& expected, const LiteralSlice& actual,
const tensorflow::gtl::optional<ErrorSpec>& error) {
if (error.has_value()) {
VLOG(1) << "Expects near";
- return Near(expected, actual, *error);
+ return StatusToAssertion(literal_comparison::Near(
+ expected, actual, *error, /*detailed_message=*/false, &OnMiscompare));
}
VLOG(1) << "Expects equal";
- return Equal(expected, actual);
-}
-
-/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
- const Literal& expected, const Literal& actual,
- const tensorflow::gtl::optional<ErrorSpec>& error) {
- EXPECT_TRUE(NearOrEqual(expected, actual, error));
-}
-
-/* static */ string LiteralTestUtil::MultiIndexAsString(
- tensorflow::gtl::ArraySlice<int64> multi_index) {
- return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
-}
-
-/* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape(
- tensorflow::gtl::ArraySlice<int64> new_dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major, const Literal& literal) {
- int64 new_num_elements = 1;
- for (int64 i = 0; i < new_dimensions.size(); ++i) {
- new_num_elements *= new_dimensions[i];
- }
- CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
- CHECK_EQ(new_dimensions.size(), minor_to_major.size());
-
- auto new_literal = MakeUnique<Literal>(
- ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
-
- // Create a new shape with the given minor-to-major layout. This shape is used
- // solely for converting linear address to multi-dimensional addresses when
- // writing elements to the new literal.
- Shape shape_with_layout = new_literal->shape();
- *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
-
- // Copy data into new literal, element-by-element.
- for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
- std::vector<int64> from_multi_index =
- IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
- std::vector<int64> to_multi_index =
- IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
- switch (literal.shape().element_type()) {
- case PRED:
- new_literal->Set<bool>(to_multi_index,
- literal.Get<bool>(from_multi_index));
- break;
- case U8:
- new_literal->Set<uint8>(to_multi_index,
- literal.Get<uint8>(from_multi_index));
- break;
- case U32:
- new_literal->Set<uint32>(to_multi_index,
- literal.Get<uint32>(from_multi_index));
- break;
- case S32:
- new_literal->Set<int32>(to_multi_index,
- literal.Get<int32>(from_multi_index));
- break;
- case U64:
- new_literal->Set<uint64>(to_multi_index,
- literal.Get<uint64>(from_multi_index));
- break;
- case S64:
- new_literal->Set<int64>(to_multi_index,
- literal.Get<int64>(from_multi_index));
- break;
- case F32:
- new_literal->Set<float>(to_multi_index,
- literal.Get<float>(from_multi_index));
- break;
- case F64:
- new_literal->Set<double>(to_multi_index,
- literal.Get<double>(from_multi_index));
- break;
- case C64:
- new_literal->Set<complex64>(to_multi_index,
- literal.Get<complex64>(from_multi_index));
- break;
- default:
- LOG(FATAL) << "Unhandled primitive element type: "
- << PrimitiveType_Name(literal.shape().element_type());
- }
- }
-
- return new_literal;
+ return StatusToAssertion(literal_comparison::Equal(expected, actual));
}
} // namespace xla
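After this rewrite, literal_test_util.cc is a thin adapter over literal_comparison: StatusToAssertion converts the comparison Status into a gunit AssertionResult, and OnMiscompare is supplied as the callback that dumps the expected, actual, and mismatch literals to the test temp directory when a comparison fails. A hedged sketch of how a test could enumerate those dumps afterwards, assuming the usual TensorFlow Env and proto headers and an otherwise empty temp directory; this helper is illustrative and not part of the patch:

    // Collects the binary-proto dumps written by OnMiscompare; the file names
    // follow the "tempfile-<host>-<usec>-<name>" pattern used above.
    std::vector<string> dumps;
    TF_CHECK_OK(tensorflow::Env::Default()->GetMatchingPaths(
        tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "tempfile-*"),
        &dumps));
    for (const string& path : dumps) {
      LiteralProto proto;
      TF_CHECK_OK(
          tensorflow::ReadBinaryProto(tensorflow::Env::Default(), path, &proto));
      LOG(INFO) << "recovered dumped literal from " << path;
    }
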
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index a755568c0f..d1b8a6cf0b 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/error_spec.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -38,282 +39,190 @@ limitations under the License.
namespace xla {
-// Structure describing permissible absolute and relative error bounds.
-struct ErrorSpec {
- explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false)
- : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {}
-
- float abs; // Absolute error bound.
- float rel; // Relative error bound.
-
- // If relaxed_nans is true then any result is valid if we are expecting NaNs.
- // In effect, this allows the tested operation to produce incorrect results
- // for inputs outside its mathematical domain.
- bool relaxed_nans;
-};
-
// Utility class for making expectations/assertions related to XLA literals.
class LiteralTestUtil {
public:
// Asserts that the given shapes have the same rank, dimension sizes, and
// primitive types.
- static ::testing::AssertionResult EqualShapes(const Shape& expected,
- const Shape& actual);
- static void AssertEqualShapes(const Shape& expected, const Shape& actual);
+ static ::testing::AssertionResult EqualShapes(
+ const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT;
// Asserts that the provided shapes are equal as defined in AssertEqualShapes
// and that they have the same layout.
- static void AssertEqualShapesAndLayouts(const Shape& expected,
- const Shape& actual);
-
- // If the given literal's data type is bfloat16, converts it to a float
- // literal; otherwise, returns a copy of it. If the literal is a tuple,
- // recursively converts its elements.
- static std::unique_ptr<Literal> ConvertBF16ToF32(const Literal& bf16_literal);
-
- // If the given literal's data type is float, converts it to a bfloat16
- // literal; otherwise, returns a copy of it. If the literal is a tuple,
- // recursively converts its elements.
- static std::unique_ptr<Literal> ConvertF32ToBF16(const Literal& f32_literal);
-
- // Asserts that the expected and actual literals are (bitwise) equal for all
- // elements in the literal. Also, asserts that the rank, dimensions sizes, and
- // primitive type are equal.
- static ::testing::AssertionResult Equal(
- const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT;
+ static ::testing::AssertionResult EqualShapesAndLayouts(
+ const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT;
- // Expects that expected and actual are Equal.
- static void ExpectEqual(const Literal& expected, const Literal& actual,
- const string& message = "");
-
- // Expects that expected and actual are Not Equal.
- static void ExpectNotEqual(const Literal& expected, const Literal& actual);
+ static ::testing::AssertionResult Equal(const LiteralSlice& expected,
+ const LiteralSlice& actual)
+ TF_MUST_USE_RESULT;
// Asserts the given literal is (bitwise) equal to the given expected values.
template <typename NativeT>
- static void ExpectR0Equal(NativeT expected, const Literal& actual);
+ static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);
+
template <typename NativeT>
static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
- const Literal& actual);
+ const LiteralSlice& actual);
template <typename NativeT>
static void ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
- const Literal& actual);
+ const LiteralSlice& actual);
+
template <typename NativeT>
static void ExpectR3Equal(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
expected,
- const Literal& actual);
+ const LiteralSlice& actual);
// Asserts the given literal is (bitwise) equal to the given array.
template <typename NativeT>
static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
- const Literal& actual);
+ const LiteralSlice& actual);
template <typename NativeT>
static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
- const Literal& actual);
+ const LiteralSlice& actual);
template <typename NativeT>
static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
- const Literal& actual);
+ const LiteralSlice& actual);
- // Asserts that the expected and actual literals are within the given error
- // bound for all elements. Also, asserts that the rank, dimensions sizes, and
- // bounds are equivalent.
+ // Decorates literal_comparison::Near() with an AssertionResult return type.
//
- // Tuples are matched recursively. When comparing tensors of
- // non-floating-point type, checks for exact equality, ignoring the ErrorSpec.
- //
- // If the shape of the literals is neither a complex/floating-point tensor nor
- // a tuple which contains a complex/floating-point tensor, Near() is
- // equivalent to Equal(). We don't raise an error in this case, because we
- // want to allow callers to call Near() even if they have no preconceptions
- // about the shapes being compared.
- //
- // If detailed_message is true, then the error message in the assertion result
- // will contain a more detailed breakdown of mismatches.
+ // See comment on literal_comparison::Near().
static ::testing::AssertionResult Near(
- const Literal& expected, const Literal& actual, const ErrorSpec& error,
+ const LiteralSlice& expected, const LiteralSlice& actual,
+ const ErrorSpec& error_spec,
bool detailed_message = false) TF_MUST_USE_RESULT;
- // Expects expected and actual to be Near with the given error.
- static void ExpectNear(const Literal& expected, const Literal& actual,
- const ErrorSpec& error, const string& message = "");
-
// Asserts the given literal is within the given error bound of the given
// expected values. Only supported for floating point values.
template <typename NativeT>
- static void ExpectR0Near(NativeT expected, const Literal& actual,
+ static void ExpectR0Near(NativeT expected, const LiteralSlice& actual,
const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
- const Literal& actual, const ErrorSpec& error);
+ const LiteralSlice& actual, const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
- const Literal& actual, const ErrorSpec& error);
+ const LiteralSlice& actual, const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR3Near(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
expected,
- const Literal& actual, const ErrorSpec& error);
+ const LiteralSlice& actual, const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR4Near(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
- const Literal& actual, const ErrorSpec& error);
+ const LiteralSlice& actual, const ErrorSpec& error);
// Asserts the given literal is within the given error bound of the given
// array. Only supported for floating point values.
template <typename NativeT>
static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
- const Literal& actual,
+ const LiteralSlice& actual,
const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
- const Literal& actual,
+ const LiteralSlice& actual,
const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
- const Literal& actual,
+ const LiteralSlice& actual,
const ErrorSpec& error);
// If the error spec is given, returns whether the expected and the actual are
// within the error bound; otherwise, returns whether they are equal. Tuples
// will be compared recursively.
static ::testing::AssertionResult NearOrEqual(
- const Literal& expected, const Literal& actual,
+ const LiteralSlice& expected, const LiteralSlice& actual,
const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
- // If the error spec is given, expects the expected and the actual to be near;
- // otherwise, expects them to be equal. Tuples will be compared recursively.
- static void ExpectNearOrEqual(
- const Literal& expected, const Literal& actual,
- const tensorflow::gtl::optional<ErrorSpec>& error);
-
- // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
- // be returned for a 2-dimensional index with dimension 0 index equal to 7,
- // dimension 1 equal to 8.
- static string MultiIndexAsString(
- tensorflow::gtl::ArraySlice<int64> multi_index);
-
- // Creates a literal with a new shape with the given new dimensions using the
- // data in the given input literal. For reshaping purposes the (flat) data
- // buffer of the input literal is assumed to have the given minor_to_major
- // layout order.
- static std::unique_ptr<Literal> Reshape(
- tensorflow::gtl::ArraySlice<int64> new_dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major,
- const Literal& literal);
-
- // Creates a literal with the supplied shape, and uses the provided value
- // generator to populate the literal's values.
- // Returns the new literal object, or an error Status if failed.
- template <
- PrimitiveType type,
- typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape,
- const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
-
- // Creates a literal with the supplied shape, and initializes the literal
- // values using a normal distribution with given mean and stddev standard
- // deviation, and using the engine as entropy generator.
- // Returns the new literal object, or an error Status if failed.
- template <
- PrimitiveType type, typename E,
- typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, E* engine, T mean, T stddev);
-
- // Creates a literal with the supplied shape, and initializes the literal
- // values using a normal distribution with given mean and stddev standard
- // deviation.
- // Returns the new literal object, or an error Status if failed.
- template <
- PrimitiveType type,
- typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, T mean, T stddev);
-
private:
TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
};
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
- const Literal& actual) {
- ExpectEqual(*Literal::CreateR0<NativeT>(expected), actual);
+ const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR0<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Equal(
- tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual) {
- ExpectEqual(*Literal::CreateR1<NativeT>(expected), actual);
+ tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR1<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
- const Literal& actual) {
- ExpectEqual(*Literal::CreateR2<NativeT>(expected), actual);
+ const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR2<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3Equal(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
- const Literal& actual) {
- ExpectEqual(*Literal::CreateR3<NativeT>(expected), actual);
+ const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR3<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
- const Array2D<NativeT>& expected, const Literal& actual) {
- ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual);
+ const Array2D<NativeT>& expected, const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
- const Array3D<NativeT>& expected, const Literal& actual) {
- ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual);
+ const Array3D<NativeT>& expected, const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
- const Array4D<NativeT>& expected, const Literal& actual) {
- ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual);
+ const Array4D<NativeT>& expected, const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
- const Literal& actual,
+ const LiteralSlice& actual,
const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR0<NativeT>(expected), actual, error);
+ EXPECT_TRUE(Near(*Literal::CreateR0<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Near(
- tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual,
+ tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR1<NativeT>(expected), actual, error);
+ EXPECT_TRUE(Near(*Literal::CreateR1<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
- const Literal& actual, const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR2<NativeT>(expected), actual, error);
+ const LiteralSlice& actual, const ErrorSpec& error) {
+ EXPECT_TRUE(Near(*Literal::CreateR2<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3Near(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
- const Literal& actual, const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR3<NativeT>(expected), actual, error);
+ const LiteralSlice& actual, const ErrorSpec& error) {
+ EXPECT_TRUE(Near(*Literal::CreateR3<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -321,63 +230,29 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
- const Literal& actual, const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR4<NativeT>(expected), actual, error);
+ const LiteralSlice& actual, const ErrorSpec& error) {
+ EXPECT_TRUE(Near(*Literal::CreateR4<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
- const Array2D<NativeT>& expected, const Literal& actual,
+ const Array2D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error);
+ EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
- const Array3D<NativeT>& expected, const Literal& actual,
+ const Array3D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error);
+ EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
- const Array4D<NativeT>& expected, const Literal& actual,
+ const Array4D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error);
-}
-
-template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralTestUtil::CreateRandomLiteral(
- const Shape& shape,
- const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
- using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
- TF_RET_CHECK(shape.element_type() == type);
- std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
- TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indexes) {
- return generator(indexes);
- }));
- return std::move(literal);
-}
-
-template <PrimitiveType type, typename E, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
- T stddev) {
- using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
- std::normal_distribution<NativeT> generator(mean, stddev);
- return CreateRandomLiteral<type, NativeT>(
- shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
- return generator(*engine);
- });
-}
-
-template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
- std::minstd_rand0 engine;
- return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
+ EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error));
}
} // namespace xla
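With the header trimmed down, ErrorSpec now lives in error_spec.h and every comparison entry point takes LiteralSlice, so tuple elements can be compared in place without copying them out of the result literal. A short usage sketch, assuming a std::unique_ptr<Literal> named result_literal that holds a (f32[2,2], f32[3]) tuple; the names and values are illustrative:

    // Elements {0} and {1} of the tuple are viewed in place via LiteralSlice.
    LiteralTestUtil::ExpectR2Near<float>(
        {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0}),
        ErrorSpec(1e-5));
    LiteralTestUtil::ExpectR1Equal<float>(
        {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1}));
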
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index 9d619a77c7..bbac7285ae 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -34,7 +34,7 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
std::unique_ptr<Literal> literal = Literal::MakeTuple({
Literal::CreateR0<int32>(42).get(), Literal::CreateR0<int32>(64).get(),
});
- LiteralTestUtil::ExpectEqual(*literal, *literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
}
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
@@ -97,6 +97,15 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
}
}
+TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
+ auto expected = Literal::CreateR1<int32>({1, 2, 3});
+ auto actual = Literal::CreateR1<int32>({4, 5, 6});
+ ::testing::AssertionResult result =
+ LiteralTestUtil::Equal(*expected, *actual);
+ EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}"));
+ EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}"));
+}
+
TEST(LiteralTestUtilTest, NearComparatorR1) {
auto a =
Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
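
The hunks above replace the void-returning LiteralTestUtil::ExpectEqual/ExpectNear helpers with Equal/Near predicates wrapped in EXPECT_TRUE, and the new NotEqualHasValuesInMessage test relies on those predicates returning ::testing::AssertionResult. A minimal standalone sketch of that gtest pattern, using a hypothetical VectorsEqual helper rather than the XLA code, looks like this:

#include <string>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"

// Hypothetical comparator in the style of LiteralTestUtil::Equal: it returns
// ::testing::AssertionResult so callers can write EXPECT_TRUE(...) and still
// get a descriptive message on failure.
::testing::AssertionResult VectorsEqual(const std::vector<int>& expected,
                                        const std::vector<int>& actual) {
  if (expected == actual) {
    return ::testing::AssertionSuccess();
  }
  return ::testing::AssertionFailure()
         << "expected has " << expected.size() << " elements, actual has "
         << actual.size() << " elements";
}

TEST(AssertionResultSketch, FailureCarriesMessage) {
  ::testing::AssertionResult result = VectorsEqual({1, 2, 3}, {4, 5, 6});
  EXPECT_FALSE(result);
  // Same pattern as NotEqualHasValuesInMessage above: inspect the message.
  EXPECT_THAT(result.message(), ::testing::HasSubstr("elements"));
}
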
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 44c6811df8..96858c00d6 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -210,12 +210,12 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0}));
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>(
{{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralView::Create(*result_literal, {1}));
+ LiteralSlice(*result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2}));
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
@@ -239,16 +239,16 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1}));
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>(
{{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralView::Create(*result_literal, {0, 0}));
+ LiteralSlice(*result_literal, {0, 0}));
LiteralTestUtil::ExpectR2Equal<float>(
{{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralView::Create(*result_literal, {0, 1}));
+ LiteralSlice(*result_literal, {0, 1}));
LiteralTestUtil::ExpectR2Equal<float>(
{{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralView::Create(*result_literal, {0, 2}));
+ LiteralSlice(*result_literal, {0, 2}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
@@ -274,9 +274,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0}));
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1}));
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
@@ -321,9 +321,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>(
{{56.0f, 46.0f}, {36.0f, 26.0f}},
- LiteralView::Create(*result_literal, {0}));
+ LiteralSlice(*result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>(
- {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1}));
+ {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
@@ -361,9 +361,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>(
- {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0}));
+ {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>(
- {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1}));
+ {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
@@ -391,16 +391,16 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(result_0);
LiteralTestUtil::ExpectR2Equal<float>(
{{-1.0, -2.0}, {-3.0, -4.0}},
- LiteralView::Create(*result_0_literal, {0}));
+ LiteralSlice(*result_0_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>(
- {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1}));
+ {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1}));
ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0});
std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(result_1);
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0}));
+ {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>(
- {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1}));
+ {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
@@ -447,7 +447,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
for (int i = 0; i < kElementCount; ++i) {
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}),
+ {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}),
error_spec_);
}
}
@@ -502,7 +502,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
for (int i = 0; i < kFanout; ++i) {
for (int j = 0; j < kFanout; ++j) {
LiteralTestUtil::ExpectR0Near<float>(
- i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}),
+ i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}),
error_spec_);
}
}
@@ -548,7 +548,7 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
index.push_back(0);
}
LiteralTestUtil::ExpectR0Equal<float>(
- 165.0, LiteralView::Create(*result_literal, index));
+ 165.0, LiteralSlice(*result_literal, index));
}
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
@@ -754,9 +754,9 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR1Equal<float>(
- {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0}));
+ {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>(
- {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1}));
+ {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index e859b3059e..88797a7d0a 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -35,9 +35,9 @@ namespace xla {
/* static */ TestAllocator* LocalClientTestBase::allocator_;
-StatusOr<se::DeviceMemoryBase> TestAllocator::Allocate(int device_ordinal,
- uint64 size,
- bool retry_on_failure) {
+StatusOr<OwningDeviceMemory> TestAllocator::Allocate(int device_ordinal,
+ uint64 size,
+ bool retry_on_failure) {
VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")";
{
tensorflow::mutex_lock lock(count_mutex_);
@@ -48,8 +48,7 @@ StatusOr<se::DeviceMemoryBase> TestAllocator::Allocate(int device_ordinal,
retry_on_failure);
}
-tensorflow::Status TestAllocator::Deallocate(int device_ordinal,
- se::DeviceMemoryBase* mem) {
+Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
VLOG(2) << "Deallocate(" << device_ordinal << ")";
{
tensorflow::mutex_lock lock(count_mutex_);
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
index 3bbb760c80..258226523d 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.h
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -46,10 +46,9 @@ class TestAllocator : public StreamExecutorMemoryAllocator {
platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) {
}
- StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
- bool retry_on_failure) override;
- tensorflow::Status Deallocate(int device_ordinal,
- se::DeviceMemoryBase* mem) override;
+ StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
+ bool retry_on_failure) override;
+ Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
// Return the number of allocations that have been performed.
int64 allocation_count() const;
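
TestAllocator::Allocate now returns StatusOr<OwningDeviceMemory> instead of a raw se::DeviceMemoryBase, and Deallocate takes the memory handle by value rather than by pointer, moving ownership of allocations into an RAII wrapper. A toy sketch of that ownership idea, with made-up names standing in for the StreamExecutor types and using plain malloc/free, under the assumption that the owning wrapper frees its block on destruction:

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-ins only: RawBlock plays the role of se::DeviceMemoryBase
// and OwningBlock the role of OwningDeviceMemory. The point is that an owning
// return type cannot be leaked by a caller that forgets to deallocate.
struct RawBlock {
  void* ptr = nullptr;
  std::size_t size = 0;
};

class OwningBlock {
 public:
  explicit OwningBlock(RawBlock block) : block_(block) {}
  OwningBlock(OwningBlock&& other) noexcept : block_(other.block_) {
    other.block_ = RawBlock{};
  }
  OwningBlock(const OwningBlock&) = delete;
  ~OwningBlock() { std::free(block_.ptr); }  // released automatically
  const RawBlock& get() const { return block_; }

 private:
  RawBlock block_;
};

OwningBlock Allocate(std::size_t size) {
  return OwningBlock(RawBlock{std::malloc(size), size});
}

int main() {
  OwningBlock block = Allocate(64);
  std::printf("allocated %zu bytes\n", block.get().size);
  // No explicit Deallocate call is needed; the destructor frees the block.
  return 0;
}
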
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index 464cc01214..27fd36e06a 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 0a603f4954..b745522ff0 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include <new>
#include <utility>
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
@@ -108,7 +107,7 @@ class MultiOutputFusionTest : public HloTestBase {
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
auto actual = ExecuteAndTransfer(
std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1});
- LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
}
void RunTest1D(bool manual_fusion, int size) {
@@ -168,7 +167,7 @@ class MultiOutputFusionTest : public HloTestBase {
Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f}));
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
- LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
}
};
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index 97dab860c0..838f1b4e2f 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
@@ -161,7 +160,7 @@ XLA_TEST_F(ParamsTest, MissingParameter) {
auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2");
auto computation_status = builder.Build();
- ASSERT_NE(computation_status.status(), tensorflow::Status::OK());
+ ASSERT_NE(computation_status.status(), Status::OK());
}
XLA_TEST_F(ParamsTest, UnusedParameter) {
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 29a4f75001..1a2de6937c 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -273,11 +273,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
&execution_options_));
}
- LiteralTestUtil::ExpectEqual(*result1, *result2);
- LiteralTestUtil::ExpectEqual(*result1, *result3);
- LiteralTestUtil::ExpectNotEqual(*result1, *result4);
- LiteralTestUtil::ExpectNotEqual(*result4, *result5);
- LiteralTestUtil::ExpectNotEqual(*result5, *result6);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3));
+ EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4));
+ EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5));
+ EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6));
}
XLA_TEST_F(PrngTest, TenValuesN01) {
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index bcc05c2d41..d671d40456 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index d7462d581b..a4580cd71d 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -656,9 +656,9 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
std::unique_ptr<Literal> expected =
Literal::CreateR2FromArray2D<float>(expected_array);
if (use_bfloat16()) {
- expected = LiteralTestUtil::ConvertF32ToBF16(*expected);
+ expected = Literal::ConvertF32ToBF16(*expected);
}
- LiteralTestUtil::ExpectEqual(*expected, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
}
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
@@ -731,7 +731,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal);
+ Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_);
}
@@ -753,7 +753,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal);
+ Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_);
}
@@ -817,7 +817,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
// Since the reshape is a no-op, verify that it does not change the underlying
// data.
if (use_bfloat16()) {
- auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal);
+ auto expected = Literal::ConvertF32ToBF16(*input_literal);
EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
} else {
EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
@@ -886,7 +886,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+ Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -915,7 +915,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+ Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -944,7 +944,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+ Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -974,7 +974,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+ Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -1003,7 +1003,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal)
+ Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
->Relayout(input_literal->shape().layout());
// Specify the requested output shape explicitly to ensure that this reshape
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
index 8cbfcc6f5c..7cfca781ac 100644
--- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -100,7 +100,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
@@ -135,7 +135,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
index 32db45f8a6..f334a8c131 100644
--- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
@@ -41,7 +41,7 @@ class RoundTripTransferTest : public ClientLibraryTestBase {
client_->TransferToServer(original).ConsumeValueOrDie();
std::unique_ptr<Literal> result =
client_->Transfer(*data).ConsumeValueOrDie();
- LiteralTestUtil::ExpectEqual(original, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(original, *result));
}
};
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index f35bc43a49..308d3fc78a 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -390,7 +390,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
&execution_options_)
.ConsumeValueOrDie();
auto expected_literal = Literal::CreateR0<uint32>(dividend / divisor);
- LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
}
}
@@ -431,7 +431,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
&execution_options_)
.ConsumeValueOrDie();
auto expected_literal = Literal::CreateR0<uint32>(dividend % divisor);
- LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
}
}
diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
index e2067bc1b8..0063e7ad41 100644
--- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc
+++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
@@ -175,7 +175,7 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) {
transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer));
- LiteralTestUtil::ExpectEqual(*literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
@@ -189,7 +189,7 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer));
- LiteralTestUtil::ExpectEqual(*literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
@@ -209,7 +209,7 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer));
- LiteralTestUtil::ExpectEqual(*literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
@@ -224,7 +224,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer));
- LiteralTestUtil::ExpectEqual(*literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
@@ -243,7 +243,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer));
- LiteralTestUtil::ExpectEqual(*literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index 5c287bac6a..7552224f10 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
@@ -515,7 +514,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
class TupleHloTest : public HloTestBase {};
// Disabled on the interpreter because bitcast doesn't exist on the interpreter.
-TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) {
+XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) {
const char* testcase = R"(
HloModule m
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 7944b5132f..3c9a01653c 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -84,8 +84,8 @@ Status ParseOneProfileOutputLine(
string match_percentage = "\\d+\\.\\d\\d%";
string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)";
string match_usecs = "([0-9.]+) usec";
- string match_flops = "([^ ]+)";
- string match_trops = "([^ ]+)";
+ string match_flops = "([^ ]*)";
+ string match_trops = "([^ ]*)";
string match_bytes_per_sec = "([0-9.TGMKi]+)B/s";
string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle";
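
The regex tweak above, from "([^ ]+)" to "([^ ]*)", makes the flops/trops columns optional: "+" requires at least one non-space character, while "*" also matches an empty field. A quick standalone illustration of the difference with std::regex (independent of whatever regex library the test itself uses):

#include <iostream>
#include <regex>
#include <string>

int main() {
  const std::string with_value = "12.3GFLOP/s";
  const std::string empty_field;

  const std::regex one_or_more("([^ ]+)");   // old pattern: needs >= 1 char
  const std::regex zero_or_more("([^ ]*)");  // new pattern: empty field is OK

  std::cout << std::regex_match(with_value, one_or_more) << "\n";    // 1
  std::cout << std::regex_match(empty_field, one_or_more) << "\n";   // 0
  std::cout << std::regex_match(empty_field, zero_or_more) << "\n";  // 1
  return 0;
}
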
diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc
index 6e3061b78a..373c0d2d8d 100644
--- a/tensorflow/compiler/xla/text_literal_writer.cc
+++ b/tensorflow/compiler/xla/text_literal_writer.cc
@@ -30,7 +30,7 @@ limitations under the License.
namespace xla {
-/* static */ tensorflow::Status TextLiteralWriter::WriteToPath(
+/* static */ Status TextLiteralWriter::WriteToPath(
const Literal& literal, tensorflow::StringPiece path) {
std::unique_ptr<tensorflow::WritableFile> f;
auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f);
@@ -43,7 +43,7 @@ namespace xla {
return s;
}
- tensorflow::Status status;
+ Status status;
tensorflow::WritableFile* f_ptr = f.get();
literal.EachCellAsString(
[f_ptr, &status](tensorflow::gtl::ArraySlice<int64> indices,
diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h
index 7375493f43..0a1235b5e0 100644
--- a/tensorflow/compiler/xla/text_literal_writer.h
+++ b/tensorflow/compiler/xla/text_literal_writer.h
@@ -37,8 +37,8 @@ namespace xla {
// This should be readable by xla::TextLiteralReader.
class TextLiteralWriter {
public:
- static tensorflow::Status WriteToPath(const Literal& literal,
- tensorflow::StringPiece path);
+ static Status WriteToPath(const Literal& literal,
+ tensorflow::StringPiece path);
private:
TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter);
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 156a06c596..d0e7af8844 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -481,10 +481,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kReal:
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index e100d8cda1..131aded95a 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -938,13 +938,13 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest,
TEST_F(HloParserTest, Empty) {
const string original = "";
auto result = Parse(original);
- EXPECT_NE(tensorflow::Status::OK(), result.status());
+ EXPECT_NE(Status::OK(), result.status());
}
TEST_F(HloParserTest, Garbage) {
const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
auto result = Parse(original);
- EXPECT_NE(tensorflow::Status::OK(), result.status());
+ EXPECT_NE(Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongOpcode) {
@@ -958,7 +958,7 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
)";
auto result = Parse(original);
- EXPECT_NE(tensorflow::Status::OK(), result.status());
+ EXPECT_NE(Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongShape) {
@@ -970,7 +970,7 @@ ENTRY %blabla (x: g32[]) -> g32[] {
)";
auto result = Parse(original);
- EXPECT_NE(tensorflow::Status::OK(), result.status());
+ EXPECT_NE(Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongOperandsSize) {
@@ -983,7 +983,7 @@ ENTRY %blabla (x: f32[]) -> pred[] {
)";
auto result = Parse(original);
- EXPECT_NE(tensorflow::Status::OK(), result.status());
+ EXPECT_NE(Status::OK(), result.status());
}
TEST_F(HloParserTest, OperandNotFound) {
@@ -994,7 +994,7 @@ ENTRY %blabla (x: f32[]) -> pred[] {
}
)";
auto result = Parse(original);
- EXPECT_NE(tensorflow::Status::OK(), result.status());
+ EXPECT_NE(Status::OK(), result.status());
}
TEST_F(HloParserTest, MoreConstants) {
@@ -1036,7 +1036,7 @@ ENTRY %some_2 () -> f32[2] {
)";
auto result = Parse(original);
- EXPECT_NE(tensorflow::Status::OK(), result.status());
+ EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"expects nested array in rank 1, but sees larger");
}
@@ -1050,7 +1050,7 @@ ENTRY %some_2x3 () -> f32[2,3] {
)";
auto result = Parse(original);
- EXPECT_NE(tensorflow::Status::OK(), result.status());
+ EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"expects nested array in rank 2, but sees 1");
}
@@ -1064,7 +1064,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] {
)";
auto result = Parse(original);
- EXPECT_NE(tensorflow::Status::OK(), result.status());
+ EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"expects 3 elements in the [0]th element");
}
@@ -1079,7 +1079,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] {
)";
auto result = Parse(original);
- EXPECT_NE(tensorflow::Status::OK(), result.status());
+ EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"is out of range for literal's primitive type F16");
}
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 750d72d797..b895ac045c 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -814,6 +814,12 @@ enum UnaryOperation {
// Elementwise, computes clz(x).
UNOP_CLZ = 17;
+
+ // Elementwise, computes exp(x)-1.
+ UNOP_EXPM1 = 18;
+
+ // Elementwise, computes log(x+1).
+ UNOP_LOG1P = 19;
}
message UnaryOpRequest {
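
The two new enum values correspond to the expm1/log1p math functions. The reason they exist as separate ops rather than being composed from exp and log is numerical: for small |x|, exp(x) - 1 and log(1 + x) lose most of their significant digits to cancellation, while the fused forms stay accurate. A small standalone check with <cmath> (ordinary C++, not the XLA ops themselves):

#include <cmath>
#include <cstdio>

int main() {
  const double x = 1e-12;
  // Composed forms lose precision to cancellation near zero...
  std::printf("exp(x) - 1    = %.17g\n", std::exp(x) - 1.0);
  std::printf("log(1 + x)    = %.17g\n", std::log(1.0 + x));
  // ...while the dedicated functions stay accurate to about 1e-12.
  std::printf("std::expm1(x) = %.17g\n", std::expm1(x));
  std::printf("std::log1p(x) = %.17g\n", std::log1p(x));
  return 0;
}
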
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index abdbdb4cd2..0f9c80404a 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -71,6 +71,7 @@ py_library(
"//tensorflow/contrib/memory_stats:memory_stats_py",
"//tensorflow/contrib/meta_graph_transform",
"//tensorflow/contrib/metrics:metrics_py",
+ "//tensorflow/contrib/mixed_precision:mixed_precision",
"//tensorflow/contrib/model_pruning",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 9f5459f41d..9aad772f0a 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -60,6 +60,7 @@ from tensorflow.contrib import lookup
from tensorflow.contrib import losses
from tensorflow.contrib import memory_stats
from tensorflow.contrib import metrics
+from tensorflow.contrib import mixed_precision
from tensorflow.contrib import model_pruning
from tensorflow.contrib import nccl
from tensorflow.contrib import nn
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py
index 1be1c96dd3..35877224b8 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/contrib/autograph/converters/break_statements.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import gast
+
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
@@ -52,8 +54,13 @@ class BreakStatementTransformer(transformer.Base):
def _guard_if_present(self, block, var_name):
"""Prevents the block from executing if var_name is set."""
+
+ # If we don't have statements that immediately depend on the break
+ # we still need to make sure that the break variable remains
+ # used, in case the break becomes useful in later stages of transformation.
+ # Not having this broke the break_in_inner_loop test.
if not block:
- return block
+ block = [gast.Pass()]
template = """
if not var_name:
block
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index 935a2786db..d7ddbe8a04 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Handles control flow statements: while, if."""
+"""Handles control flow statements: while, for, if."""
from __future__ import absolute_import
from __future__ import division
@@ -25,6 +25,7 @@ from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.contrib.autograph.pyct.static_analysis import cfg
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
@@ -47,9 +48,6 @@ class SymbolNamer(object):
class ControlFlowTransformer(transformer.Base):
"""Transforms control flow structures like loops an conditionals."""
- def __init__(self, context):
- super(ControlFlowTransformer, self).__init__(context)
-
def _create_cond_branch(self, body_name, aliased_orig_names,
aliased_new_names, body, returns):
if aliased_orig_names:
@@ -98,30 +96,63 @@ class ControlFlowTransformer(transformer.Base):
body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
-
- if body_scope.created - orelse_scope.created:
- raise ValueError(
- 'The if branch creates new symbols that the else branch does not.')
- if orelse_scope.created - body_scope.created:
- raise ValueError(
- 'The else branch creates new symbols that the if branch does not.')
-
- modified = tuple(body_scope.modified | orelse_scope.modified)
- all_referenced = body_scope.referenced | orelse_scope.referenced
+ body_defs = body_scope.created | body_scope.modified
+ orelse_defs = orelse_scope.created | orelse_scope.modified
+ live = anno.getanno(node, 'live_out')
+
+ # We'll need to check if we're closing over variables that are defined
+ # elsewhere in the function
+ # NOTE: we can only detect syntactic closure in the scope
+ # of the code passed in. If the AutoGraph'd function itself closes
+ # over other variables, this analysis won't take that into account.
+ defined = anno.getanno(node, 'defined_in')
+
+ # We only need to return variables that are
+ # - modified by one or both branches
+  #  - live (or have a live parent) at the end of the conditional
+ modified = []
+ for def_ in body_defs | orelse_defs:
+ def_with_parents = set((def_,)) | def_.support_set
+ if live & def_with_parents:
+ modified.append(def_)
+
+ # We need to check if live created variables are balanced
+ # in both branches
+ created = live & (body_scope.created | orelse_scope.created)
+
+ # The if statement is illegal if there are variables that are created,
+ # that are also live, but both branches don't create them.
+ if created:
+ if created != (body_scope.created & live):
+ raise ValueError(
+ 'The main branch does not create all live symbols that the else '
+ 'branch does.')
+ if created != (orelse_scope.created & live):
+ raise ValueError(
+ 'The else branch does not create all live symbols that the main '
+ 'branch does.')
# Alias the closure variables inside the conditional functions
# to avoid errors caused by the local variables created in the branch
# functions.
- need_alias = (
- (body_scope.modified | orelse_scope.modified) -
- (body_scope.created | orelse_scope.created))
- aliased_orig_names = tuple(need_alias)
- aliased_new_names = tuple(
- self.context.namer.new_symbol(s.ssf(), all_referenced)
- for s in aliased_orig_names)
- alias_map = dict(zip(aliased_orig_names, aliased_new_names))
- node_body = ast_util.rename_symbols(node.body, alias_map)
- node_orelse = ast_util.rename_symbols(node.orelse, alias_map)
+ # We will alias variables independently for body and orelse scope,
+ # because different branches might write different variables.
+ aliased_body_orig_names = tuple(body_scope.modified - body_scope.created)
+ aliased_orelse_orig_names = tuple(orelse_scope.modified -
+ orelse_scope.created)
+ aliased_body_new_names = tuple(
+ self.context.namer.new_symbol(s.ssf(), body_scope.referenced)
+ for s in aliased_body_orig_names)
+ aliased_orelse_new_names = tuple(
+ self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced)
+ for s in aliased_orelse_orig_names)
+
+ alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
+ alias_orelse_map = dict(
+ zip(aliased_orelse_orig_names, aliased_orelse_new_names))
+
+ node_body = ast_util.rename_symbols(node.body, alias_body_map)
+ node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)
if not modified:
# When the cond would return no value, we leave the cond called without
@@ -134,26 +165,47 @@ class ControlFlowTransformer(transformer.Base):
else:
results = gast.Tuple([s.ast() for s in modified], None)
- body_name = self.context.namer.new_symbol('if_true', all_referenced)
- orelse_name = self.context.namer.new_symbol('if_false', all_referenced)
+ body_name = self.context.namer.new_symbol('if_true', body_scope.referenced)
+ orelse_name = self.context.namer.new_symbol('if_false',
+ orelse_scope.referenced)
if modified:
- body_returns = tuple(
- alias_map[s] if s in aliased_orig_names else s for s in modified)
+
+ def build_returns(aliased_names, alias_map, scope):
+ """Builds list of return variables for a branch of a conditional."""
+ returns = []
+ for s in modified:
+ if s in aliased_names:
+ returns.append(alias_map[s])
+ else:
+ if s not in scope.created | defined:
+ raise ValueError(
+ 'Attempting to return variable "%s" from the true branch of '
+ 'a conditional, but it was not closed over, or created in '
+ 'this branch.' % str(s))
+ else:
+ returns.append(s)
+ return tuple(returns)
+
+ body_returns = build_returns(aliased_body_orig_names, alias_body_map,
+ body_scope)
+ orelse_returns = build_returns(aliased_orelse_orig_names,
+ alias_orelse_map, orelse_scope)
+
else:
- body_returns = templates.replace('tf.ones(())')[0].value
+ body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value
body_def = self._create_cond_branch(
body_name,
- aliased_orig_names=tuple(aliased_orig_names),
- aliased_new_names=tuple(aliased_new_names),
+ aliased_orig_names=tuple(aliased_body_orig_names),
+ aliased_new_names=tuple(aliased_body_new_names),
body=node_body,
returns=body_returns)
orelse_def = self._create_cond_branch(
orelse_name,
- aliased_orig_names=tuple(aliased_orig_names),
- aliased_new_names=tuple(aliased_new_names),
+ aliased_orig_names=tuple(aliased_orelse_orig_names),
+ aliased_new_names=tuple(aliased_orelse_new_names),
body=node_orelse,
- returns=body_returns)
+ returns=orelse_returns)
cond_expr = self._create_cond_expr(results, node.test, body_name,
orelse_name)
@@ -284,6 +336,7 @@ class ControlFlowTransformer(transformer.Base):
def transform(node, context):
- t = ControlFlowTransformer(context)
- node = t.visit(node)
+ cfg.run_analyses(node, cfg.Liveness(context))
+ cfg.run_analyses(node, cfg.Defined(context))
+ node = ControlFlowTransformer(context).visit(node)
return node
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py
index c5610b16b4..1a863590f9 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/contrib/autograph/converters/control_flow_test.py
@@ -22,6 +22,7 @@ from tensorflow.contrib.autograph.converters import control_flow
from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
@@ -95,6 +96,91 @@ class ControlFlowTest(converter_test_base.TestCase):
with self.test_session() as sess:
self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
+ def test_imbalanced_aliasing(self):
+
+ def test_fn(n):
+ if n > 0:
+ n = 3
+ return n
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node, control_flow_ops.cond) as result:
+ with self.test_session() as sess:
+ self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(2))))
+ self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3))))
+
+ def test_ignore_unread_variable(self):
+
+ def test_fn(n):
+ b = 3 # pylint: disable=unused-variable
+ if n > 0:
+ b = 4
+ return n
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
+ with self.test_session() as sess:
+ self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(3))))
+ self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3))))
+
+ def test_handle_temp_variable(self):
+
+ def test_fn_using_temp(x, y, w):
+ if x < y:
+ z = x + y
+ else:
+ w = 2
+ tmp = w
+ z = x - tmp
+ return z, w
+
+ node = self.parse_and_analyze(test_fn_using_temp, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
+ with self.test_session() as sess:
+ z, w = sess.run(
+ result.test_fn_using_temp(
+ constant_op.constant(-3), constant_op.constant(3),
+ constant_op.constant(3)))
+ self.assertEqual(0, z)
+ self.assertEqual(3, w)
+ z, w = sess.run(
+ result.test_fn_using_temp(
+ constant_op.constant(3), constant_op.constant(-3),
+ constant_op.constant(3)))
+ self.assertEqual(1, z)
+ self.assertEqual(2, w)
+
+ def test_fn_ignoring_temp(x, y, w):
+ if x < y:
+ z = x + y
+ else:
+ w = 2
+ tmp = w
+ z = x - tmp
+ return z
+
+ node = self.parse_and_analyze(test_fn_ignoring_temp, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
+ with self.test_session() as sess:
+ z = sess.run(
+ result.test_fn_ignoring_temp(
+ constant_op.constant(-3), constant_op.constant(3),
+ constant_op.constant(3)))
+ self.assertEqual(0, z)
+ z = sess.run(
+ result.test_fn_ignoring_temp(
+ constant_op.constant(3), constant_op.constant(-3),
+ constant_op.constant(3)))
+ self.assertEqual(1, z)
+
def test_simple_for(self):
def test_fn(l):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
index 230e4cc0f3..ad97fdfa8e 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
@@ -135,8 +135,7 @@ class CfgBuilder(gast.NodeVisitor):
# Handle the body
self.visit_statements(node.body)
body_exit = self.current_leaves
- self.current_leaves = []
- self.current_leaves.append(test)
+ self.current_leaves = [test]
# Handle the orelse
self.visit_statements(node.orelse)
self.current_leaves.extend(body_exit)
@@ -149,12 +148,15 @@ class CfgBuilder(gast.NodeVisitor):
self.continue_.append([])
# Handle the body
self.visit_statements(node.body)
+ body_exit = self.current_leaves
self.current_leaves.extend(self.continue_.pop())
self.set_current_leaves(test)
# Handle the orelse
self.visit_statements(node.orelse)
# The break statements and the test go to the next node
self.current_leaves.extend(self.break_.pop())
+ # Body and orelse statements can reach out of the loop
+ self.current_leaves.extend(body_exit)
def visit_For(self, node):
iter_ = CfgNode(node.iter)
@@ -162,9 +164,15 @@ class CfgBuilder(gast.NodeVisitor):
self.break_.append([])
self.continue_.append([])
self.visit_statements(node.body)
+ body_exit = self.current_leaves
self.current_leaves.extend(self.continue_.pop())
self.set_current_leaves(iter_)
+ # Handle the orelse
+ self.visit_statements(node.orelse)
+ # The break statements and the test go to the next node
self.current_leaves.extend(self.break_.pop())
+ # Body and orelse statements can reach out of the loop
+ self.current_leaves.extend(body_exit)
def visit_Break(self, node):
self.break_[-1].extend(self.current_leaves)
@@ -395,7 +403,13 @@ class Liveness(Backward):
super(Liveness, self).__init__('live', context)
def get_gen_kill(self, node, _):
+ # A variable's parents are live if it is live
+ # e.g. x is live if x.y is live. This means gen needs to return
+ # all parents of a variable (if it's an Attribute or Subscript).
+ # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x)
gen = activity.get_read(node.value, self.context)
+ gen = functools.reduce(lambda left, right: left | right.support_set, gen,
+ gen)
kill = activity.get_updated(node.value, self.context)
return gen, kill
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py
index af7eaf30e8..8d723ce09d 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py
@@ -247,6 +247,47 @@ class CFGTest(test.TestCase):
anno.getanno(body[2], 'defined_in'),
frozenset(map(qual_names.QN, ('x', 'g'))))
+ def test_loop_else(self):
+
+ # Disabling useless-else-on-loop error, because 'break' and 'continue'
+ # canonicalization are a separate analysis pass, and here we test
+ # the CFG analysis in isolation.
+ def for_orelse(x):
+ y = 0
+ for i in range(len(x)):
+ x += i
+ else: # pylint: disable=useless-else-on-loop
+ y = 1
+ return x, y
+
+ def while_orelse(x, i):
+ y = 0
+ while x < 10:
+ x += i
+ else: # pylint: disable=useless-else-on-loop
+ y = 1
+ return x, y
+
+ for f in (for_orelse, while_orelse):
+ node, ctx = self._parse_and_analyze(f, {})
+ cfg.run_analyses(node, cfg.ReachingDefinitions(ctx))
+ body = node.body[0].body
+ return_node = body[-1]
+ reaching_defs = anno.getanno(return_node, 'definitions_in')
+
+ # Y could be defined by Assign(Num(0)) or Assign(Num(1))
+ # X could be defined as an argument or an AugAssign.
+ y_defs = [node for var, node in reaching_defs if str(var) == 'y']
+ x_defs = [node for var, node in reaching_defs if str(var) == 'x']
+
+ self.assertEqual(set((gast.Assign,)), set(type(def_) for def_ in y_defs))
+ self.assertEqual(set((0, 1)), set(def_.value.n for def_ in y_defs))
+ self.assertEqual(len(y_defs), 2)
+ self.assertEqual(
+ set((gast.arguments, gast.AugAssign)),
+ set(type(def_) for def_ in x_defs))
+ self.assertEqual(len(x_defs), 2)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 44a8ffaf4b..04e32267cc 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -422,6 +422,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
GradientStats(*gradients_t, *hessians_t, bucket_idx);
}
present_gradient_stats *= normalizer_ratio;
+ GradientStats not_present =
+ root_gradient_stats - present_gradient_stats;
+ // If there was (almost) no sparsity, fix the default direction to LEFT.
+ bool fixed_default_direction = not_present.IsAlmostZero();
GradientStats left_gradient_stats;
for (int64 element_idx = start_index; element_idx < end_index;
@@ -441,6 +445,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
// backward pass gradients.
GradientStats right_gradient_stats =
present_gradient_stats - left_gradient_stats;
+
{
NodeStats left_stats_default_left =
ComputeNodeStats(root_gradient_stats - right_gradient_stats);
@@ -457,7 +462,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
best_dimension_idx = dimension_id;
}
}
- {
+ // Consider calculating the default direction only when there were
+ // enough missing examples.
+ if (!fixed_default_direction) {
NodeStats left_stats_default_right =
ComputeNodeStats(left_gradient_stats);
NodeStats right_stats_default_right =
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
index 28834ef55b..5cd37ec67e 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import random
+
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.proto import split_info_pb2
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
@@ -399,6 +401,65 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.6, split_node.split.threshold)
+ def testMakeSparseSplitDefaultDirectionIsStable(self):
+ """Tests default direction is stable when no sparsity."""
+ random.seed(1123)
+ for _ in range(50):
+ with self.test_session() as sess:
+ grad = random.random()
+ hessian = random.random()
+ # The data looks like the following (divide by the num of steps 2).
+ # Gradients | Partition | bucket ID |
+ # (grad, hessian) | 0 | -1 |
+ # And then 100 buckets of
+ # (grad/100, hessian/100), so there is no sparsity.
+ n_buckets = 100
+
+ # 1 for the overall sum, and 100 buckets.
+ partition_ids = array_ops.constant(
+ [0] * (n_buckets + 1), dtype=dtypes.int32)
+ # We have only 1 dimension in our sparse feature column.
+
+ bucket_ids = [-1] + [n for n in range(100)]
+ bucket_ids = array_ops.constant(bucket_ids, dtype=dtypes.int64)
+ dimension_ids = array_ops.constant(
+ [0] * (n_buckets + 1), dtype=dtypes.int64)
+ bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1)
+
+ gradients = [grad] + [grad / n_buckets] * n_buckets
+ gradients = array_ops.constant(gradients)
+ hessians = [hessian] + [hessian / n_buckets] * n_buckets
+ hessians = array_ops.constant(hessians)
+
+ boundaries = [x * 1 for x in range(n_buckets + 1)]
+ bucket_boundaries = array_ops.constant(boundaries, dtype=dtypes.float32)
+
+ partitions, gains, splits = (
+ split_handler_ops.build_sparse_inequality_splits(
+ num_minibatches=2,
+ partition_ids=partition_ids,
+ bucket_ids=bucket_ids,
+ gradients=gradients,
+ hessians=hessians,
+ bucket_boundaries=bucket_boundaries,
+ l1_regularization=0,
+ l2_regularization=2,
+ tree_complexity_regularization=0,
+ min_node_weight=0,
+ feature_column_group_id=0,
+ bias_feature_id=-1,
+ class_id=-1,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ partitions, gains, splits = (sess.run([partitions, gains, splits]))
+ self.assertAllEqual([0], partitions)
+ self.assertEqual(1, len(splits))
+
+ split_info = split_info_pb2.SplitInfo()
+ split_info.ParseFromString(splits[0])
+ self.assertTrue(
+ split_info.split_node.HasField(
+ 'sparse_float_binary_split_default_left'))
+
def testMakeMulticlassSparseSplit(self):
"""Tests split handler op."""
with self.test_session() as sess:
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index e529b25b3c..c5f7072aea 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -14,22 +14,27 @@
# ==============================================================================
"""Tools for working with object-based checkpoints.
-
-For creating and managing dependencies:
-@@CheckpointableObjectGraph
+Visualization and inspection:
@@dot_graph_from_checkpoint
@@object_metadata
+
+Creating and managing dependencies:
+@@Checkpointable
+@@CheckpointableObjectGraph
@@NoDependency
@@split_dependency
+@@UniqueNameTracker
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
+from tensorflow.python.training.checkpointable import Checkpointable
from tensorflow.python.training.checkpointable import NoDependency
from tensorflow.python.training.checkpointable_utils import object_metadata
diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD
index a5681ffa61..cbb9852ccf 100644
--- a/tensorflow/contrib/checkpoint/python/BUILD
+++ b/tensorflow/contrib/checkpoint/python/BUILD
@@ -8,12 +8,35 @@ py_library(
name = "checkpoint",
srcs_version = "PY2AND3",
deps = [
+ ":containers",
":split_dependency",
":visualize",
],
)
py_library(
+ name = "containers",
+ srcs = ["containers.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = ["//tensorflow/python:checkpointable"],
+)
+
+py_test(
+ name = "containers_test",
+ srcs = ["containers_test.py"],
+ deps = [
+ ":containers",
+ "//tensorflow/python:checkpointable",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:training",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "split_dependency",
srcs = ["split_dependency.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py
new file mode 100644
index 0000000000..82aa04e38f
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/containers.py
@@ -0,0 +1,77 @@
+"""Checkpointable data structures."""
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.training import checkpointable as checkpointable_lib
+
+
+class UniqueNameTracker(checkpointable_lib.CheckpointableBase):
+ """Adds dependencies on checkpointable objects with name hints.
+
+ Useful for creating dependencies with locally unique names.
+
+ Example usage:
+ ```python
+ class SlotManager(tf.contrib.checkpoint.Checkpointable):
+
+ def __init__(self):
+ # Create a dependency named "slotdeps" on the container.
+ self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker()
+ slotdeps = self.slotdeps
+ slots = []
+ slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x"
+ slots.append(slotdeps.track(tfe.Variable(4.), "y"))
+ slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1"
+ ```
+ """
+
+ def __init__(self):
+ self._maybe_initialize_checkpointable()
+ self._name_counts = {}
+
+ def track(self, checkpointable, base_name):
+ """Add a dependency on `checkpointable`.
+
+ Args:
+ checkpointable: An object to add a checkpoint dependency on.
+ base_name: A name hint, which is uniquified to determine the dependency
+ name.
+ Returns:
+ `checkpointable`, for chaining.
+ Raises:
+ ValueError: If `checkpointable` is not a checkpointable object.
+ """
+
+ if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase):
+ raise ValueError(
+ ("Expected a checkpointable value, got %s which does not inherit "
+ "from CheckpointableBase.") % (checkpointable,))
+
+ def _format_name(prefix, number):
+ if number > 0:
+ return "%s_%d" % (prefix, number)
+ else:
+ return prefix
+
+ count = self._name_counts.get(base_name, 0)
+ candidate = _format_name(base_name, count)
+ while self._lookup_dependency(candidate) is not None:
+ count += 1
+ candidate = _format_name(base_name, count)
+ self._name_counts[base_name] = count + 1
+ return self._track_checkpointable(checkpointable, name=candidate)
diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py
new file mode 100644
index 0000000000..15775f4cb3
--- /dev/null
+++ b/tensorflow/contrib/checkpoint/python/containers_test.py
@@ -0,0 +1,100 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import six
+
+from tensorflow.contrib.checkpoint.python import containers
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpointable
+from tensorflow.python.training import checkpointable_utils
+from tensorflow.python.training.checkpointable_utils import object_metadata
+
+
+class UniqueNameTrackerTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testNames(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+ x1 = resource_variable_ops.ResourceVariable(2.)
+ x2 = resource_variable_ops.ResourceVariable(3.)
+ x3 = resource_variable_ops.ResourceVariable(4.)
+ y = resource_variable_ops.ResourceVariable(5.)
+ slots = containers.UniqueNameTracker()
+ slots.track(x1, "x")
+ slots.track(x2, "x")
+ slots.track(x3, "x_1")
+ slots.track(y, "y")
+ self.evaluate((x1.initializer, x2.initializer, x3.initializer,
+ y.initializer))
+ save_root = checkpointable_utils.Checkpoint(slots=slots)
+ save_path = save_root.save(checkpoint_prefix)
+
+ restore_slots = checkpointable.Checkpointable()
+ restore_root = checkpointable_utils.Checkpoint(
+ slots=restore_slots)
+ status = restore_root.restore(save_path)
+ restore_slots.x = resource_variable_ops.ResourceVariable(0.)
+ restore_slots.x_1 = resource_variable_ops.ResourceVariable(0.)
+ restore_slots.x_1_1 = resource_variable_ops.ResourceVariable(0.)
+ restore_slots.y = resource_variable_ops.ResourceVariable(0.)
+ status.assert_consumed().run_restore_ops()
+ self.assertEqual(2., self.evaluate(restore_slots.x))
+ self.assertEqual(3., self.evaluate(restore_slots.x_1))
+ self.assertEqual(4., self.evaluate(restore_slots.x_1_1))
+ self.assertEqual(5., self.evaluate(restore_slots.y))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testExample(self):
+ class SlotManager(checkpointable.Checkpointable):
+
+ def __init__(self):
+ self.slotdeps = containers.UniqueNameTracker()
+ slotdeps = self.slotdeps
+ slots = []
+ slots.append(slotdeps.track(
+ resource_variable_ops.ResourceVariable(3.), "x"))
+ slots.append(slotdeps.track(
+ resource_variable_ops.ResourceVariable(4.), "y"))
+ slots.append(slotdeps.track(
+ resource_variable_ops.ResourceVariable(5.), "x"))
+ self.slots = slots
+
+ manager = SlotManager()
+ self.evaluate([v.initializer for v in manager.slots])
+ checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager)
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = checkpoint.save(checkpoint_prefix)
+ metadata = object_metadata(save_path)
+ dependency_names = []
+ for node in metadata.nodes:
+ for child in node.children:
+ dependency_names.append(child.local_name)
+ six.assertCountEqual(
+ self,
+ dependency_names,
+ ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"])
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 1403483d28..8ede28602f 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -36,6 +36,7 @@ except ImportError:
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
+_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
class TPUClusterResolver(ClusterResolver):
@@ -70,6 +71,12 @@ class TPUClusterResolver(ClusterResolver):
def _gkeMaster():
return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
+ @staticmethod
+ def _envVarFallback():
+ if _DEFAULT_ENV_VARIABLE in os.environ:
+ return os.environ[_DEFAULT_ENV_VARIABLE]
+ return None
+
def __init__(self,
tpu=None,
zone=None,
@@ -123,8 +130,11 @@ class TPUClusterResolver(ClusterResolver):
in_gke = self._inGke()
# When using GKE with Cloud TPUs, the env variable will be set.
- if tpu is None and in_gke:
- tpu = self._gkeMaster()
+ if tpu is None:
+ if in_gke:
+ tpu = self._gkeMaster()
+ else:
+ tpu = self._envVarFallback()
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
self._job_name = job_name
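The resulting resolution order is: an explicit `tpu` argument, then the GKE endpoints variable, then the `TPU_NAME` fallback added here. A minimal sketch of that chain, assuming only the environment-variable names from this diff (the function name is illustrative):

```python
import os

_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'

def resolve_tpu_name(tpu=None):
  # Illustrative approximation of the resolver's fallback order.
  if tpu is not None:
    return tpu                                          # explicit argument wins
  if _GKE_ENV_VARIABLE in os.environ:
    return os.environ[_GKE_ENV_VARIABLE].split(',')[0]  # GKE-provided endpoint
  return os.environ.get(_DEFAULT_ENV_VARIABLE)          # TPU_NAME, or None
```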
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 6468bed497..f5a2f91271 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -54,7 +54,6 @@ tensorflow/python/keras/datasets/fashion_mnist
tensorflow/python/keras/datasets/imdb
tensorflow/python/keras/datasets/mnist
tensorflow/python/keras/datasets/reuters
-tensorflow/python/keras/estimator
tensorflow/python/keras/initializers
tensorflow/python/keras/layers
tensorflow/python/keras/losses
@@ -333,6 +332,8 @@ tensorflow/contrib/metrics
tensorflow/contrib/metrics/python
tensorflow/contrib/metrics/python/metrics
tensorflow/contrib/metrics/python/ops
+tensorflow/contrib/mixed_precision
+tensorflow/contrib/mixed_precision/python
tensorflow/contrib/mpi_collectives/python
tensorflow/contrib/mpi_collectives/python/ops
tensorflow/contrib/model_pruning
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index 1505d3e208..2d76bf530a 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -68,6 +68,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc"
"${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc"
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index c4bdb69d82..8d24a7ae38 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -244,13 +244,11 @@ add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD
# tf_python_op_gen_main library
########################################################
set(tf_python_op_gen_main_srcs
- "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h"
- "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc"
- "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc"
- "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h"
+ "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h"
+ "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc"
)
add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs})
@@ -464,12 +462,12 @@ set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.cc"
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h"
"${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc"
- "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h"
- "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.h"
"${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.cc"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h"
"${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc"
+ "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h"
+ "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.h"
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 077cbba9d2..a25aa85251 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -23,6 +23,8 @@ removing existing functionality.
See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@Counter
+@@CheckpointInputPipelineHook
+@@CsvDataset
@@SqlDataset
@@assert_element_shape
@@ -72,8 +74,10 @@ from tensorflow.contrib.data.python.ops.grouping import group_by_window
from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
+from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
+from tensorflow.contrib.data.python.ops.readers import CsvDataset
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
from tensorflow.contrib.data.python.ops.readers import make_csv_dataset
from tensorflow.contrib.data.python.ops.readers import read_batch_features
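With `CsvDataset` now exported, a typical `tf.contrib.data` pipeline looks roughly like the sketch below; the file name, column types, and defaults are hypothetical, and this is a usage sketch rather than canonical documentation:

```python
import tensorflow as tf

# Hypothetical CSV file with an int column, a float column, and a string column.
record_defaults = [
    tf.constant([], dtype=tf.int32),       # required column (no default)
    tf.constant([0.0], dtype=tf.float32),  # optional, defaults to 0.0
    tf.constant([''], dtype=tf.string),    # optional, defaults to ''
]
dataset = tf.contrib.data.CsvDataset(
    ['data.csv'], record_defaults, header=True)
iterator = dataset.make_one_shot_iterator()
row = iterator.get_next()  # a tuple of three scalar tensors per CSV record
```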
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD
index c56910c783..7b69e10441 100644
--- a/tensorflow/contrib/data/kernels/BUILD
+++ b/tensorflow/contrib/data/kernels/BUILD
@@ -30,6 +30,16 @@ cc_library(
)
cc_library(
+ name = "csv_dataset_op",
+ srcs = ["csv_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+cc_library(
name = "ignore_errors_dataset_op",
srcs = ["ignore_errors_dataset_op.cc"],
deps = [
@@ -63,6 +73,7 @@ cc_library(
cc_library(
name = "dataset_kernels",
deps = [
+ ":csv_dataset_op",
":directed_interleave_dataset_op",
":ignore_errors_dataset_op",
":prefetching_kernels",
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
new file mode 100644
index 0000000000..76e54a284e
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -0,0 +1,508 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/parsing_ops.cc.
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+
+namespace tensorflow {
+namespace {
+
+class CSVDatasetOp : public DatasetOpKernel {
+ public:
+ explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ const Tensor* filenames_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
+ OP_REQUIRES(
+ ctx, filenames_tensor->dims() <= 1,
+ errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+
+ OpInputList record_defaults_list;
+ OP_REQUIRES_OK(ctx,
+ ctx->input_list("record_defaults", &record_defaults_list));
+ for (int i = 0; i < record_defaults_list.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2,
+ errors::InvalidArgument(
+ "There should only be 1 default per field but field ", i,
+ " has ", record_defaults_list[i].NumElements()));
+ }
+
+ const Tensor* select_cols_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor));
+ OP_REQUIRES(ctx, select_cols_tensor->dims() == 1,
+ errors::InvalidArgument("`select_cols` must be a vector."));
+
+ int64 buffer_size;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
+ OP_REQUIRES(ctx, buffer_size > 0,
+ errors::InvalidArgument("buffer_size should be positive"));
+
+ string delim;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "field_delim", &delim));
+ OP_REQUIRES(ctx, delim.size() == 1,
+ errors::InvalidArgument("field_delim should be only 1 char"));
+
+ bool header;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "header", &header));
+
+ bool use_quote_delim;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "use_quote_delim",
+ &use_quote_delim));
+ string na_value;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "na_value", &na_value));
+
+ std::vector<Tensor> record_defaults;
+ record_defaults.reserve(record_defaults_list.size());
+ for (const Tensor& t : record_defaults_list) {
+ record_defaults.push_back(t);
+ }
+
+ std::vector<string> filenames;
+ filenames.reserve(filenames_tensor->NumElements());
+ for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
+ filenames.push_back(filenames_tensor->flat<string>()(i));
+ }
+
+ std::vector<int64> select_cols;
+ select_cols.reserve(select_cols_tensor->NumElements());
+ for (int i = 0; i < select_cols_tensor->NumElements(); ++i) {
+ select_cols.push_back(select_cols_tensor->flat<int64>()(i));
+ }
+ OP_REQUIRES(
+ ctx, output_types_.size() == select_cols.size() || select_cols.empty(),
+ errors::InvalidArgument("select_cols should match output size"));
+ for (int i = 1; i < select_cols.size(); i++) {
+ OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i],
+ errors::InvalidArgument(
+ "select_cols should be strictly increasing indices"));
+ }
+ OP_REQUIRES(
+ ctx, select_cols.empty() || select_cols.front() >= 0,
+ errors::InvalidArgument("select_cols should be non-negative indices"));
+ bool select_all_cols = select_cols.empty();
+
+ *output = new Dataset(
+ ctx, std::move(filenames), header, buffer_size, output_types_,
+ output_shapes_, std::move(record_defaults), std::move(select_cols),
+ select_all_cols, use_quote_delim, delim[0], std::move(na_value));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header,
+ int64 buffer_size, const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes,
+ std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
+ bool select_all_cols, bool use_quote_delim, char delim,
+ string na_value)
+ : GraphDatasetBase(ctx),
+ filenames_(std::move(filenames)),
+ header_(header),
+ buffer_size_(buffer_size),
+ out_type_(output_types),
+ output_shapes_(output_shapes),
+ record_defaults_(std::move(record_defaults)),
+ select_cols_(std::move(select_cols)),
+ select_all_cols_(select_all_cols),
+ use_quote_delim_(use_quote_delim),
+ delim_(delim),
+ na_value_(std::move(na_value)) {}
+
+ std::unique_ptr<IteratorBase> MakeIterator(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::CSV")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override { return out_type_; }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() override { return "CSVDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ // TODO(rachelim): Implement this
+ std::vector<Node*> input_tensors;
+ TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
+ return errors::Unimplemented("CSVDataset: AsGraphDefInternal");
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ do {
+ // We are currently processing a file, so try to read the next record
+ if (buffered_input_stream_) {
+ Status s = ReadRecord(ctx, out_tensors);
+ if (s.ok() || !errors::IsOutOfRange(s)) {
+ // Not at the end of file, return OK or non-EOF errors to caller.
+ *end_of_sequence = false;
+ return s;
+ }
+ // We have reached the end of the current file, so maybe
+ // move on to next file.
+ ResetStreamsLocked();
+ ++current_file_index_;
+ }
+ // Iteration ends when there are no more files to process.
+ if (current_file_index_ == dataset()->filenames_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ // TODO(rachelim): Implement save
+ return errors::Unimplemented("CSVDataset: SaveInternal");
+ }
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ // TODO(rachelim): Implement restore
+ return errors::Unimplemented("CSVDataset: RestoreInternal");
+ }
+
+ private:
+ // Reads a record by parsing the input buffer, and converting extracted
+ // fields to output tensors as we go.
+ Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ // Extracts fields from line(s) from the buffered input stream.
+ out_tensors->reserve(dataset()->record_defaults_.size());
+
+ string input;
+ TF_RETURN_IF_ERROR(buffered_input_stream_->ReadLine(&input));
+
+ size_t current_idx = 0;
+ size_t num_fields_parsed = 0;
+ size_t selector_idx = 0; // Keep track of index into select_cols
+
+ while (current_idx < input.size()) {
+ // In each iteration, parse one field
+ if (input[current_idx] == '\n' || input[current_idx] == '\r') {
+ // This should never happen, because buffered input reader splits
+ // input on newlines.
+ return errors::InvalidArgument("Parsing error.");
+ }
+
+ bool quoted = false;
+ bool include =
+ (dataset()->select_all_cols_ ||
+ dataset()->select_cols_[selector_idx] == num_fields_parsed);
+
+ if (dataset()->use_quote_delim_ && input[current_idx] == '"') {
+ quoted = true;
+ current_idx++;
+ }
+
+ // Parse the body of the field
+ string field;
+ if (!quoted) {
+ while (current_idx < input.size() &&
+ input[current_idx] != dataset()->delim_) {
+ if ((dataset()->use_quote_delim_ && input[current_idx] == '"') ||
+ input[current_idx] == '\n' || input[current_idx] == '\r') {
+ return errors::InvalidArgument(
+ "Unquoted fields cannot have quotes/CRLFs inside");
+ }
+ if (include) field += input[current_idx];
+ current_idx++;
+ } // Exit condition: end of input, or current index at delim
+
+ // Go to next field or the end
+ current_idx++;
+ } else {
+ // Quoted field needs to be ended with '"' and delim or end
+ while (true) {
+ if (current_idx >= input.size() - 1 || input.empty()) {
+ if (current_idx == input.size() - 1 &&
+ input[current_idx] == '"') {
+ // We're at the end of the input, and the quote terminates the
+ // record. Go to end.
+ current_idx++;
+ break;
+ }
+ // If there's no terminating quote, it means our buffered record
+ // line reader split a record up. This can happen if there is a
+ // newline encased in quotes. The next line is also part of the
+ // record, so we read it and reset the index.
+ if (include && current_idx == input.size() - 1) {
+ // TODO(rachelim): Instead of building up a string, keep track
+ // of terminal indices (or starting char* and length)
+ // Also look into using /lib/strings/Scanner
+ field += input[current_idx];
+ }
+ if (include) {
+ field += '\n';
+ }
+ current_idx = 0;
+ Status s = buffered_input_stream_->ReadLine(&input);
+ if (!s.ok()) {
+ return errors::InvalidArgument(
+ "Quoted field has to end with quote followed by delim, "
+ "CRLF, or EOF");
+ }
+ } else if (input[current_idx] == '"' &&
+ input[current_idx + 1] == dataset()->delim_) {
+ // End of field, go to next field or end
+ current_idx += 2;
+ break;
+ } else if (input[current_idx] == '"') {
+ // Current char is a quote. Since we're not at end of field,
+ // the next character must also be a quote.
+ if (input[current_idx + 1] != '"') {
+ return errors::InvalidArgument(
+ "Quote inside a string has to be escaped by another "
+ "quote");
+ }
+ if (include) field += '"';
+ current_idx += 2;
+ } else {
+ if (include) field += input[current_idx];
+ current_idx++;
+ }
+ }
+ }
+
+ num_fields_parsed++;
+
+ if (include) {
+ // Add the tensor to the result
+ TF_RETURN_IF_ERROR(FieldToOutput(ctx, std::move(field),
+ selector_idx, out_tensors));
+ selector_idx++;
+ // Terminate early if we have all the fields we want
+ if (selector_idx == dataset()->select_cols_.size())
+ return Status::OK();
+ }
+ } // Exit condition: current_idx has reached the end of record
+
+ // Check if the last field is empty, and include it if necessary
+ bool include =
+ (dataset()->select_all_cols_ ||
+ dataset()->select_cols_[selector_idx] == num_fields_parsed);
+ if (include && !input.empty() &&
+ input[input.size() - 1] == dataset()->delim_) {
+ TF_RETURN_IF_ERROR(
+ FieldToOutput(ctx, string(), selector_idx, out_tensors));
+ }
+
+ // Check that number of fields matches
+ if (out_tensors->size() != dataset()->out_type_.size()) {
+ return errors::InvalidArgument("Expect ", dataset()->out_type_.size(),
+ " fields but have ",
+ out_tensors->size(), " in record");
+ }
+ return Status::OK();
+ }
+
+ // Given a string field, and its index in the output,
+ // converts it to a Tensor of the right type and adds it to the
+ // out_tensors vector.
+ Status FieldToOutput(IteratorContext* ctx, string field,
+ size_t output_idx,
+ std::vector<Tensor>* out_tensors) {
+ if (output_idx >= dataset()->out_type_.size()) {
+ // We can get here if we're selecting all columns, but the number of
+ // fields exceeds the number of defaults provided
+ return errors::InvalidArgument("Expect ", dataset()->out_type_.size(),
+ " fields but have more in record");
+ }
+ const DataType& dtype = dataset()->out_type_[output_idx];
+ Tensor component(ctx->allocator({}), dtype, {});
+ if ((field.empty() || field == dataset()->na_value_) &&
+ dataset()->record_defaults_[output_idx].NumElements() != 1) {
+ // If the field is empty or NA value, and default is not given,
+ // report error.
+ return errors::InvalidArgument("Field ", output_idx,
+ " is required but missing in record!");
+ }
+
+ switch (dtype) {
+ // For each case, if the field is empty, we use the default.
+ // Otherwise, we convert it to the right type.
+ case DT_INT32: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<int32>()() =
+ dataset()->record_defaults_[output_idx].flat<int32>()(0);
+ } else {
+ int32 value;
+ if (!strings::safe_strto32(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid int32: ", field);
+ }
+ component.scalar<int32>()() = value;
+ }
+ break;
+ }
+ case DT_INT64: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<int64>()() =
+ dataset()->record_defaults_[output_idx].flat<int64>()(0);
+ } else {
+ int64 value;
+ if (!strings::safe_strto64(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid int64: ", field);
+ }
+ component.scalar<int64>()() = value;
+ }
+ break;
+ }
+ case DT_FLOAT: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<float>()() =
+ dataset()->record_defaults_[output_idx].flat<float>()(0);
+ } else {
+ float value;
+ if (!strings::safe_strtof(field.c_str(), &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid float: ", field);
+ }
+ component.scalar<float>()() = value;
+ }
+ break;
+ }
+ case DT_DOUBLE: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<double>()() =
+ dataset()->record_defaults_[output_idx].flat<double>()(0);
+ } else {
+ double value;
+ if (!strings::safe_strtod(field.c_str(), &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid double: ", field);
+ }
+ component.scalar<double>()() = value;
+ }
+ break;
+ }
+ case DT_STRING: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<string>()() =
+ dataset()->record_defaults_[output_idx].flat<string>()(0);
+ } else {
+ component.scalar<string>()() = std::move(field);
+ }
+ break;
+ }
+ default:
+ return errors::InvalidArgument("csv: data type ", dtype,
+ " not supported in field ",
+ output_idx);
+ }
+ out_tensors->push_back(std::move(component));
+ return Status::OK();
+ }
+
+ // Sets up reader streams to read from the file at `current_file_index_`.
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (current_file_index_ >= dataset()->filenames_.size()) {
+ return errors::InvalidArgument(
+ "current_file_index_:", current_file_index_,
+ " >= filenames_.size():", dataset()->filenames_.size());
+ }
+
+ // Actually move on to next file.
+ TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
+ dataset()->filenames_[current_file_index_], &file_));
+ input_stream_.reset(
+ new io::RandomAccessInputStream(file_.get(), false));
+ // TODO(rachelim): Maintain our own buffer so we don't read every record
+ // twice
+ buffered_input_stream_.reset(new io::BufferedInputStream(
+ input_stream_.get(), dataset()->buffer_size_, false));
+ if (dataset()->header_) {
+ // Ignore header line
+ string str;
+ Status s = buffered_input_stream_->ReadLine(&str);
+ if (errors::IsOutOfRange(s)) {
+ return errors::InvalidArgument("Can't read header of empty file");
+ }
+ }
+ return Status::OK();
+ }
+
+ // Resets all reader streams.
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ input_stream_.reset();
+ buffered_input_stream_.reset();
+ file_.reset();
+ }
+
+ mutex mu_;
+ std::unique_ptr<io::RandomAccessInputStream> input_stream_
+ GUARDED_BY(mu_);
+ std::unique_ptr<io::BufferedInputStream> buffered_input_stream_
+ GUARDED_BY(mu_);
+ size_t current_file_index_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<RandomAccessFile> file_
+ GUARDED_BY(mu_); // must outlive input_stream_
+ }; // class Iterator
+
+ const std::vector<string> filenames_;
+ const bool header_;
+ const int64 buffer_size_;
+ const DataTypeVector out_type_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const std::vector<Tensor> record_defaults_;
+ const std::vector<int64> select_cols_;
+ const bool select_all_cols_;
+ const bool use_quote_delim_;
+ const char delim_;
+ const string na_value_;
+ }; // class Dataset
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+}; // class CSVDatasetOp
+
+// Register the kernel implementation for CSVDataset.
+REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
+
+} // namespace
+} // namespace tensorflow
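The quoted-field branch above handles two tricky cases: escaped quotes (`""`) inside a field, and records whose quoted fields contain newlines, which the buffered line reader splits across reads. A rough Python sketch of that part of the state machine, for illustration only (it approximates the kernel's behavior and is not the kernel itself):

```python
def parse_quoted_field(lines, line_idx, idx):
  """Sketch: parse one quoted field whose opening quote has already been
  consumed; `idx` points at the first character after it. Returns
  (field, line_idx, idx) with idx just past the closing quote."""
  field = []
  line = lines[line_idx]
  while True:
    if idx >= len(line):
      # No closing quote on this line: a newline is part of the field and the
      # record continues on the next buffered line (mirrors the kernel's
      # re-read-and-reset behavior).
      line_idx += 1
      if line_idx == len(lines):
        raise ValueError('Quoted field has to end with quote followed by '
                         'delim, CRLF, or EOF')
      field.append('\n')
      line = lines[line_idx]
      idx = 0
    elif line[idx] == '"':
      if idx + 1 < len(line) and line[idx + 1] == '"':
        field.append('"')  # an escaped "" contributes a single quote character
        idx += 2
      else:
        return ''.join(field), line_idx, idx + 1  # closing quote
    else:
      field.append(line[idx])
      idx += 1

# Example: a field with an escaped quote and an embedded newline.
lines = ['a,"he said ""hi""', 'bye",c']
# parse_quoted_field(lines, 0, 3) -> ('he said "hi"\nbye', 1, 4)
```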
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index 137deb6352..f271d269ab 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -34,6 +34,40 @@ data_input_datasets: `N` datasets with the same type that will be interleaved
according to the values of `selector_input_dataset`.
)doc");
+REGISTER_OP("CSVDataset")
+ .Input("filenames: string")
+ .Input("buffer_size: int64")
+ .Input("header: bool")
+ .Input("field_delim: string")
+ .Input("use_quote_delim: bool")
+ .Input("na_value: string")
+ .Input("select_cols: int64")
+ .Input("record_defaults: output_types")
+ .Output("handle: variant")
+ .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `filenames` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ // `buffer_size`, `header`, `field_delim`, `use_quote_delim`,
+ // `na_value` must be scalars
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
+ // `select_cols` must be a vector
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &unused));
+ // Each `record_defaults` input must be a rank-1 vector; the kernel
+ // additionally requires at most one default value per column.
+ for (size_t i = 7; i < c->num_inputs(); ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused));
+ }
+ return shape_inference::ScalarShape(c);
+ });
+
REGISTER_OP("IgnoreErrorsDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 6017e27e73..26b80fcf31 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -11,7 +11,10 @@ py_test(
size = "medium",
srcs = ["batch_dataset_op_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ ],
deps = [
":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:batching",
@@ -118,6 +121,19 @@ py_library(
)
py_test(
+ name = "csv_dataset_op_test",
+ size = "small",
+ srcs = ["csv_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test",
+ "//tensorflow/contrib/data/python/ops:readers",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "filter_dataset_op_test",
size = "small",
srcs = ["filter_dataset_op_test.py"],
@@ -411,6 +427,7 @@ py_test(
srcs = ["sql_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:readers",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
new file mode 100644
index 0000000000..641a389c03
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -0,0 +1,378 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for CsvDatasetOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+import time
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+class CsvDatasetOpTest(test.TestCase):
+
+ def _assert_datasets_equal(self, g, ds1, ds2):
+ assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, '
+ '%s') % (ds1.output_shapes,
+ ds2.output_shapes)
+ assert ds1.output_types == ds2.output_types
+ assert ds1.output_classes == ds2.output_classes
+ next1 = ds1.make_one_shot_iterator().get_next()
+ next2 = ds2.make_one_shot_iterator().get_next()
+ with self.test_session(graph=g) as sess:
+ # Run through datasets and check that outputs match, or errors match.
+ while True:
+ try:
+ op1 = sess.run(next1)
+ except (errors.OutOfRangeError, ValueError) as e:
+ # If op1 throws an exception, check that op2 throws same exception.
+ with self.assertRaises(type(e)):
+ sess.run(next2)
+ break
+ op2 = sess.run(next2)
+ self.assertAllEqual(op1, op2)
+
+ def setup_files(self, inputs):
+ filenames = []
+ for i, ip in enumerate(inputs):
+ fn = os.path.join(self.get_temp_dir(), 'temp_%d.txt' % i)
+ with open(fn, 'w') as f:
+ f.write('\n'.join(ip))
+ filenames.append(fn)
+ return filenames
+
+ def _make_test_datasets(self, inputs, **kwargs):
+ # Test by comparing its output to what we could get with map->decode_csv
+ filenames = self.setup_files(inputs)
+ dataset_expected = core_readers.TextLineDataset(filenames)
+ dataset_expected = dataset_expected.map(
+ lambda l: gen_parsing_ops.decode_csv(l, **kwargs))
+ dataset_actual = readers.CsvDataset(filenames, **kwargs)
+ return (dataset_actual, dataset_expected)
+
+ def _test_by_comparison(self, inputs, **kwargs):
+ """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
+ with ops.Graph().as_default() as g:
+ dataset_actual, dataset_expected = self._make_test_datasets(
+ inputs, **kwargs)
+ self._assert_datasets_equal(g, dataset_actual, dataset_expected)
+
+ def _test_dataset(self,
+ inputs,
+ expected_output=None,
+ expected_err_re=None,
+ **kwargs):
+ """Checks that elements produced by CsvDataset match expected output."""
+ # Convert str type because py3 tf strings are bytestrings
+ filenames = self.setup_files(inputs)
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ dataset = readers.CsvDataset(filenames, **kwargs)
+ nxt = dataset.make_one_shot_iterator().get_next()
+ if expected_err_re is None:
+ # Verify that output is expected, without errors
+ expected_output = [[
+ v.encode('utf-8') if isinstance(v, str) else v for v in op
+ ] for op in expected_output]
+ for value in expected_output:
+ op = sess.run(nxt)
+ self.assertAllEqual(op, value)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(nxt)
+ else:
+ # Verify that OpError is produced as expected
+ with self.assertRaisesOpError(expected_err_re):
+ while True:
+ try:
+ sess.run(nxt)
+ except errors.OutOfRangeError:
+ break
+
+ def testCsvDataset_floatRequired(self):
+ record_defaults = [[]] * 4
+ inputs = [['1,2,3,4']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_int(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1,2,3,4', '5,6,7,8']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_float(self):
+ record_defaults = [[0.0]] * 4
+ inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_string(self):
+ record_defaults = [['']] * 4
+ inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withQuoted(self):
+ record_defaults = [['']] * 4
+ inputs = [['1.0,2.1,"hello, it is me",4.3', '5.4,6.5,goodbye,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_mixedTypes(self):
+ record_defaults = [
+ constant_op.constant([], dtype=dtypes.int32),
+ constant_op.constant([], dtype=dtypes.float32),
+ constant_op.constant([], dtype=dtypes.string),
+ constant_op.constant([], dtype=dtypes.float64)
+ ]
+ inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withUseQuoteDelimFalse(self):
+ record_defaults = [['']] * 4
+ inputs = [['1,2,"3,4"', '"5,6",7,8']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, use_quote_delim=False)
+
+ def testCsvDataset_withFieldDelim(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1:2:3:4', '5:6:7:8']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, field_delim=':')
+
+ def testCsvDataset_withEmptyValues(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1,,3,4', ',6,7,8']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withNaValue(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1,NA,3,4', 'NA,6,7,8']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, na_value='NA')
+
+ def testCsvDataset_withSelectCols(self):
+ record_defaults = [[0]] * 2
+ inputs = [['1,2,3,4', '5,6,7,8']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, select_cols=[1, 2])
+
+ def testCsvDataset_withSelectColsTooHigh(self):
+ record_defaults = [[0]] * 2
+ inputs = [['1,2,3,4', '5,6,7,8']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 2 fields but have 1 in record',
+ record_defaults=record_defaults,
+ select_cols=[3, 4])
+
+ def testCsvDataset_withMultipleFiles(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withNewLine(self):
+ # In this case, we expect CsvDataset to behave differently from
+ # TextLineDataset->map(decode_csv), which cannot handle newlines embedded
+ # in quoted fields because lines are split before parsing.
+ record_defaults = [['']] * 4
+ inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']]
+ expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_withMultipleNewLines(self):
+ # In this case, we expect CsvDataset to behave differently from
+ # TextLineDataset->map(decode_csv), which cannot handle newlines embedded
+ # in quoted fields because lines are split before parsing.
+ record_defaults = [['']] * 4
+ inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']]
+ expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_withLeadingAndTrailingSpaces(self):
+ record_defaults = [[0.0]] * 4
+ inputs = [['0, 1, 2, 3']]
+ expected = [[0.0, 1.0, 2.0, 3.0]]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithMissingDefault(self):
+ record_defaults = [[]] * 2
+ inputs = [['0,']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Field 1 is required but missing in record!',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithFewerDefaultsThanFields(self):
+ record_defaults = [[0.0]] * 2
+ inputs = [['0,1,2,3']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 2 fields but have more in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithMoreDefaultsThanFields(self):
+ record_defaults = [[0.0]] * 5
+ inputs = [['0,1,2,3']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 5 fields but have 4 in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withHeader(self):
+ record_defaults = [[0]] * 2
+ inputs = [['col1,col2', '1,2']]
+ expected = [[1, 2]]
+ self._test_dataset(
+ inputs,
+ expected,
+ record_defaults=record_defaults,
+ header=True,
+ )
+
+ def testCsvDataset_withHeaderAndNoRecords(self):
+ record_defaults = [[0]] * 2
+ inputs = [['col1,col2']]
+ expected = []
+ self._test_dataset(
+ inputs,
+ expected,
+ record_defaults=record_defaults,
+ header=True,
+ )
+
+ def testCsvDataset_errorWithHeaderEmptyFile(self):
+ record_defaults = [[0]] * 2
+ inputs = [[]]
+ self._test_dataset(
+ inputs,
+ expected_err_re="Can't read header of empty file",
+ record_defaults=record_defaults,
+ header=True,
+ )
+
+ def testCsvDataset_withEmptyFile(self):
+ record_defaults = [['']] * 2
+ inputs = [['']] # Empty file
+ self._test_dataset(
+ inputs, expected_output=[], record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithEmptyRecord(self):
+ record_defaults = [['']] * 2
+ inputs = [['', '1,2']] # First record is empty
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 2 fields but have 0 in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withChainedOps(self):
+ # Testing that one dataset can create multiple iterators fine.
+ # `repeat` creates multiple iterators from the same C++ Dataset.
+ record_defaults = [[0]] * 4
+ inputs = [['1,,3,4', '5,6,,8']]
+ ds_actual, ds_expected = self._make_test_datasets(
+ inputs, record_defaults=record_defaults)
+ with ops.Graph().as_default() as g:
+ self._assert_datasets_equal(g,
+ ds_actual.repeat(5).prefetch(1),
+ ds_expected.repeat(5).prefetch(1))
+
+ def testCsvDataset_withTypeDefaults(self):
+ # Testing using dtypes as record_defaults for required fields
+ record_defaults = [dtypes.float32, dtypes.float32]
+ inputs = [['1.0,2.0', '3.0,4.0']]
+ self._test_dataset(
+ inputs,
+ [[1.0, 2.0], [3.0, 4.0]],
+ record_defaults=record_defaults,
+ )
+
+
+class CsvDatasetBenchmark(test.Benchmark):
+ """Benchmarks for the various ways of creating a dataset from CSV files.
+ """
+
+ def _setUp(self):
+ # Since this isn't a test.TestCase, we have to create a test dir manually.
+ gfile.MakeDirs(googletest.GetTempDir())
+ self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
+
+ self._num_cols = [4, 64, 256]
+ self._batch_size = 500
+ self._filenames = []
+ for n in self._num_cols:
+ fn = os.path.join(self._temp_dir, 'file%d.csv' % n)
+ with open(fn, 'w') as f:
+ # Just write 10 rows and use `repeat`...
+ row = ','.join(['1.23456E12' for _ in range(n)])
+ f.write('\n'.join([row for _ in range(10)]))
+ self._filenames.append(fn)
+
+ def _tearDown(self):
+ gfile.DeleteRecursively(self._temp_dir)
+
+ def _runBenchmark(self, dataset, num_cols, prefix):
+ next_element = dataset.make_one_shot_iterator().get_next()
+ with session.Session() as sess:
+ for _ in range(5):
+ sess.run(next_element)
+ deltas = []
+ for _ in range(10):
+ start = time.time()
+ sess.run(next_element)
+ end = time.time()
+ deltas.append(end - start)
+ median_wall_time = np.median(deltas) / 100
+ print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols,
+ median_wall_time))
+ self.report_benchmark(
+ iters=self._batch_size,
+ wall_time=median_wall_time,
+ name='%s_with_cols_%d' % (prefix, num_cols))
+
+ def benchmarkBatchThenMap(self):
+ self._setUp()
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [[0.0]] * num_cols}
+ dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+ dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
+ dataset = dataset.batch(self._batch_size)
+ self._runBenchmark(dataset, num_cols, 'csv_map_then_batch')
+ self._tearDown()
+
+ def benchmarkCsvDataset(self):
+ self._setUp()
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [[0.0]] * num_cols}
+ dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
+ dataset = dataset.batch(self._batch_size)
+ self._runBenchmark(dataset, num_cols, 'csv_fused_dataset')
+ self._tearDown()
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
index e26cef8ec5..4148addf28 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
@@ -22,6 +22,7 @@ import os
import sqlite3
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SqlDatasetTest(test.TestCase):
+class SqlDatasetTestBase(test.TestCase):
def _createSqlDataset(self, output_types, num_repeats=1):
dataset = readers.SqlDataset(self.driver_name, self.data_source_name,
@@ -92,6 +93,9 @@ class SqlDatasetTest(test.TestCase):
conn.commit()
conn.close()
+
+class SqlDatasetTest(SqlDatasetTestBase):
+
# Test that SqlDataset can read from a database table.
def testReadResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
@@ -652,5 +656,27 @@ class SqlDatasetTest(test.TestCase):
sess.run(get_next)
+class SqlDatasetSerializationTest(
+ SqlDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, num_repeats):
+ data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
+ driver_name = array_ops.placeholder_with_default(
+ array_ops.constant("sqlite", dtypes.string), shape=[])
+ query = ("SELECT first_name, last_name, motto FROM students ORDER BY "
+ "first_name DESC")
+ output_types = (dtypes.string, dtypes.string, dtypes.string)
+ return readers.SqlDataset(driver_name, data_source_name, query,
+ output_types).repeat(num_repeats)
+
+ def testSQLSaveable(self):
+ num_repeats = 4
+ num_outputs = num_repeats * 2
+ self.run_core_tests(lambda: self._build_dataset(num_repeats),
+ lambda: self._build_dataset(num_repeats // 2),
+ num_outputs)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 7a3e42cc72..eceecfd174 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -45,6 +45,27 @@ py_library(
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+)
+
+py_test(
+ name = "iterator_ops_test",
+ size = "small",
+ srcs = ["iterator_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":iterator_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:model_fn",
],
)
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index d736029fb0..f1d0e5cddc 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -16,10 +16,12 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.training import saver
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import session_run_hook
def make_saveable_from_iterator(iterator):
@@ -60,14 +62,14 @@ def make_saveable_from_iterator(iterator):
return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
-class _Saveable(saver.BaseSaverBuilder.SaveableObject):
+class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
"""SaveableObject for saving/restoring iterator state."""
def __init__(self, iterator_resource):
serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
specs = [
- saver.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
- iterator_resource.name + "-state")
+ saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
+ iterator_resource.name + "-state")
]
super(_Saveable, self).__init__(iterator_resource, specs,
iterator_resource.name)
@@ -75,3 +77,160 @@ class _Saveable(saver.BaseSaverBuilder.SaveableObject):
def restore(self, restored_tensors, unused_restored_shapes):
with ops.colocate_with(self.op):
return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
+
+
+class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+ """Checkpoints input pipeline state every N steps or seconds.
+
+ This hook saves the state of the iterators in the `Graph` so that when
+ training is resumed the input pipeline continues from where it left off.
+ This could potentially avoid overfitting in certain pipelines where the
+ number of training steps per eval is small compared to the dataset
+ size or if the training pipeline is pre-empted.
+
+ Differences from `CheckpointSaverHook`:
+ 1. Saves only the input pipelines in the "iterators" collection and not the
+ global variables or other saveable objects.
+ 2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.
+
+ Example of checkpointing the training pipeline:
+
+ ```python
+ est = tf.estimator.Estimator(model_fn)
+ while True:
+ est.train(
+ train_input_fn,
+ hooks=[tf.contrib.data.CheckpointInputPipelineHook(est)],
+ steps=train_steps_per_eval)
+ # Note: We do not pass the hook here.
+ metrics = est.evaluate(eval_input_fn)
+ if should_stop_the_training(metrics):
+ break
+ ```
+
+ This hook should be used if the input pipeline state needs to be saved
+ separately from the model checkpoint. Doing so may be useful for a few reasons:
+ 1. The input pipeline checkpoint may be large, if there are large shuffle
+ or prefetch buffers for instance, and may bloat the checkpoint size.
+ 2. If the input pipeline is shared between training and validation, restoring
+ the checkpoint during validation may override the validation input
+ pipeline.
+
+ For saving the input pipeline checkpoint alongside the model weights use
+ @{tf.contrib.data.make_saveable_from_iterator} directly to create a
+ `SaveableObject` and add it to the `SAVEABLE_OBJECTS` collection. Note, however,
+ that you will need to be careful not to restore the training iterator during
+ eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS
+ collection when building the eval graph.
+ """
+
+ def __init__(self, estimator):
+ """Initializes a `CheckpointInputPipelineHook`.
+
+ Args:
+ estimator: Estimator.
+
+ Raises:
+ ValueError: One of `save_steps` or `save_secs` should be set.
+ ValueError: At most one of saver or scaffold should be set.
+ """
+ # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
+ # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
+ # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
+ # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
+ # to be different to avoid conflicts with the model checkpoint.
+
+ # pylint: disable=protected-access
+ checkpoint_prefix = "input"
+ if estimator._config.num_worker_replicas > 1:
+ # Distributed setting.
+ suffix = "_{}_{}".format(estimator._config.task_type,
+ estimator._config.task_id)
+ checkpoint_prefix += suffix
+ # pylint: enable=protected-access
+
+ # We use a composition paradigm instead of inheriting from
+ # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
+ # to check whether a `CheckpointSaverHook` is already present in the list
+ # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
+ # would thwart this behavior. This hook checkpoints *only the iterators*
+ # and not the graph variables.
+ self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
+ estimator.model_dir,
+ save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access
+ save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access
+ checkpoint_basename=checkpoint_prefix + ".ckpt")
+
+ # Name for the protocol buffer file that will contain the list of most
+ # recent checkpoints stored as a `CheckpointState` protocol buffer.
+ # This file, kept in the same directory as the checkpoint files, is
+ # automatically managed by the `Saver` to keep track of recent checkpoints.
+ # The default name used by the `Saver` for this file is "checkpoint". Here
+ # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
+ # `checkpoint_dir` is the same as the model checkpoint directory, there are
+ # no conflicts during restore.
+ self._latest_filename = "checkpoint_" + checkpoint_prefix
+
+ def begin(self):
+ # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
+ # collection if no `Saver` or `Scaffold` is provided.
+ # pylint: disable=protected-access
+ if (self._checkpoint_saver_hook._saver is None and
+ self._checkpoint_saver_hook._scaffold is None):
+ iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
+ saveables = [_Saveable(i) for i in iterators]
+ self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
+ self._latest_filename)
+ # pylint: enable=protected-access
+ self._checkpoint_saver_hook.begin()
+
+ def after_create_session(self, session, coord):
+ # Check if there is an existing checkpoint. If so, restore from it.
+ # pylint: disable=protected-access
+ latest_checkpoint_path = saver_lib.latest_checkpoint(
+ self._checkpoint_saver_hook._checkpoint_dir,
+ latest_filename=self._latest_filename)
+ if latest_checkpoint_path:
+ self._checkpoint_saver_hook._get_saver().restore(session,
+ latest_checkpoint_path)
+ else:
+ # The checkpoint saved here is the state at step "global_step".
+ # Note: We do not save the GraphDef or MetaGraphDef here.
+ global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
+ self._checkpoint_saver_hook._save(session, global_step)
+ self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
+ # pylint: enable=protected-access
+
+ def before_run(self, run_context):
+ return self._checkpoint_saver_hook.before_run(run_context)
+
+ def after_run(self, run_context, run_values):
+ self._checkpoint_saver_hook.after_run(run_context, run_values)
+
+ def end(self, session):
+ self._checkpoint_saver_hook.end(session)
+
+
+class _CustomSaver(saver_lib.Saver):
+ """`Saver` with a different default `latest_filename`.
+
+ This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
+ the model ckpt saved by the `CheckpointSaverHook`.
+ """
+
+ def __init__(self, var_list, latest_filename):
+ super(_CustomSaver, self).__init__(var_list)
+ self._latest_filename = latest_filename
+
+ def save(self,
+ sess,
+ save_path,
+ global_step=None,
+ latest_filename=None,
+ meta_graph_suffix="meta",
+ write_meta_graph=True,
+ write_state=True,
+ strip_default_attrs=False):
+ return super(_CustomSaver, self).save(
+ sess, save_path, global_step, latest_filename or self._latest_filename,
+ meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
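For the alternative the docstring mentions (saving iterator state in the same checkpoint as the model weights), the usual pattern is to register a saveable built by `make_saveable_from_iterator` in the `SAVEABLE_OBJECTS` collection, and to do so only when building the training graph. A hedged sketch using graph-mode TF 1.x APIs; the dataset contents are illustrative:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
iterator = dataset.make_initializable_iterator()

# Register the iterator state so a plain tf.train.Saver picks it up. Do this
# only in the training graph, so evaluation does not restore the training
# iterator (see the caveat in the docstring above).
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

saver = tf.train.Saver()  # now saves/restores model variables and iterator state
```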
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/ops/iterator_ops_test.py
new file mode 100644
index 0000000000..30a993b1f7
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/iterator_ops_test.py
@@ -0,0 +1,123 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for experimental iterator_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import training_util
+
+
+class CheckpointInputPipelineHookTest(test.TestCase):
+
+ @staticmethod
+ def _model_fn(features, labels, mode, config):
+ del labels
+ del mode
+ del config
+ global_step = training_util.get_or_create_global_step()
+ update_global_step_op = global_step.assign_add(1)
+ latest_feature = variables.Variable(
+ 0, name='latest_feature', dtype=dtypes.int64)
+ store_latest_feature_op = latest_feature.assign(features)
+ ops.add_to_collection('my_vars', global_step)
+ ops.add_to_collection('my_vars', latest_feature)
+ return model_fn.EstimatorSpec(
+ mode='train',
+ train_op=control_flow_ops.group(
+ [update_global_step_op, store_latest_feature_op]),
+ loss=constant_op.constant(2.0))
+
+ def _read_vars(self, model_dir):
+ """Returns (global_step, latest_feature)."""
+ with ops.Graph().as_default() as g:
+ ckpt_path = saver_lib.latest_checkpoint(model_dir)
+ meta_filename = ckpt_path + '.meta'
+ saver_lib.import_meta_graph(meta_filename)
+ saver = saver_lib.Saver()
+ with self.test_session(graph=g) as sess:
+ saver.restore(sess, ckpt_path)
+ return sess.run(ops.get_collection('my_vars'))
+
+ def _build_iterator_saver_hook(self, est):
+ return iterator_ops.CheckpointInputPipelineHook(est)
+
+ def testReturnDatasetFromInputFn(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.range(10)
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+ def testBuildIteratorInInputFn(self):
+
+ def _input_fn():
+ ds = dataset_ops.Dataset.range(10)
+ iterator = ds.make_one_shot_iterator()
+ return iterator.get_next()
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+ def testDoNotRestore(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.range(10)
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+ # No hook provided, so the input pipeline is not restored and restarts.
+ est.train(_input_fn, steps=2)
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1))
+
+ def testRaiseErrorIfNoIterator(self):
+
+ def _input_fn():
+ return constant_op.constant(1, dtype=dtypes.int64)
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ with self.assertRaises(ValueError):
+ est.train(
+ _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index bbb808fbd7..11fc85d09e 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -23,10 +23,12 @@ from math import ceil
import numpy as np
from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -102,6 +104,7 @@ def _infer_type(str_val, na_value, prev_type, float_dtype):
def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
comment):
+ """Generator that yields rows of CSV file(s) in order."""
for fn in filenames:
with file_io.FileIO(fn, "r") as f:
rdr = csv.reader(
@@ -421,6 +424,146 @@ def make_csv_dataset(
return dataset
+_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
+
+
+class CsvDataset(dataset_ops.Dataset):
+ """A Dataset comprising lines from one or more CSV files."""
+
+ def __init__(self,
+ filenames,
+ record_defaults,
+ buffer_size=None,
+ header=False,
+ field_delim=",",
+ use_quote_delim=True,
+ na_value="",
+ select_cols=None):
+ """Creates a `CsvDataset` by reading and decoding CSV files.
+
+ The elements of this dataset correspond to records from the file(s).
+ RFC 4180 format is expected for CSV files
+ (https://tools.ietf.org/html/rfc4180).
+ Note that leading and trailing spaces are allowed in int or float fields.
+
+ For example, suppose we have a file 'my_file0.csv' with four CSV columns of
+ different data types:
+ ```
+ abcdefg,4.28E10,5.55E6,12
+ hijklmn,-5.3E14,,2
+ ```
+
+ We can construct a CsvDataset from it as follows:
+ ```python
+ dataset = tf.contrib.data.CsvDataset(
+ "my_file*.csv",
+ [tf.float32, # Required field, use dtype or empty tensor
+ tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0
+ tf.int32, # Required field, use dtype or empty tensor
+ ],
+ select_cols=[1,2,3] # Only parse last three columns
+ )
+ ```
+
+ The expected output of its iterations is:
+ ```python
+ nxt = dataset.make_one_shot_iterator().get_next()
+ with tf.Session() as sess:
+ while True:
+ try:
+ print(sess.run(nxt))
+ except tf.errors.OutOfRangeError:
+ break
+
+ >> (4.28e10, 5.55e6, 12)
+ >> (-5.3e14, 0.0, 2)
+ ```
+
+ Args:
+ filenames: A `tf.string` tensor containing one or more filenames.
+ record_defaults: A list of default values for the CSV fields. Each item in
+ the list is either a valid CSV `DType` (float32, float64, int32, int64,
+ string), or a `Tensor` object with one of the above types. There is one
+ item per selected column of CSV data: a scalar `Tensor` default value if
+ the column is optional, or a `DType` or empty `Tensor` if it is required.
+ If both this and `select_cols` are specified, they must have the same
+ length, and `record_defaults` is assumed to be sorted in order of
+ increasing column index.
+ buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
+ to buffer while reading files. Defaults to 4MB.
+ header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
+ have header line(s) that should be skipped when parsing. Defaults to
+ `False`.
+ field_delim: (Optional.) A `tf.string` scalar containing the delimiter
+ character that separates fields in a record. Defaults to `","`.
+ use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
+ double quotation marks as regular characters inside of string fields
+ (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
+ na_value: (Optional.) A `tf.string` scalar indicating a value that will
+ be treated as NA/NaN.
+ select_cols: (Optional.) A sorted list of column indices to select from
+ the input data. If specified, only this subset of columns will be
+ parsed. Defaults to parsing all columns.
+ """
+ super(CsvDataset, self).__init__()
+ self._filenames = ops.convert_to_tensor(
+ filenames, dtype=dtypes.string, name="filenames")
+ record_defaults = [
+ constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
+ for x in record_defaults
+ ]
+ self._record_defaults = ops.convert_n_to_tensor(
+ record_defaults, name="record_defaults")
+ self._buffer_size = convert.optional_param_to_tensor(
+ "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
+ self._header = ops.convert_to_tensor(
+ header, dtype=dtypes.bool, name="header")
+ self._field_delim = ops.convert_to_tensor(
+ field_delim, dtype=dtypes.string, name="field_delim")
+ self._use_quote_delim = ops.convert_to_tensor(
+ use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
+ self._na_value = ops.convert_to_tensor(
+ na_value, dtype=dtypes.string, name="na_value")
+ self._select_cols = convert.optional_param_to_tensor(
+ "select_cols",
+ select_cols,
+ argument_default=[],
+ argument_dtype=dtypes.int64,
+ )
+ self._output_shapes = tuple(
+ tensor_shape.scalar() for _ in range(len(record_defaults)))
+ self._output_types = tuple(d.dtype for d in self._record_defaults)
+ self._output_classes = tuple(
+ ops.Tensor for _ in range(len(record_defaults)))
+
+ def _as_variant_tensor(self):
+ # Constructs graph node for the dataset op.
+ return contrib_gen_dataset_ops.csv_dataset(
+ filenames=self._filenames,
+ record_defaults=self._record_defaults,
+ buffer_size=self._buffer_size,
+ header=self._header,
+ output_shapes=self._output_shapes,
+ field_delim=self._field_delim,
+ use_quote_delim=self._use_quote_delim,
+ na_value=self._na_value,
+ select_cols=self._select_cols,
+ )
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
def make_batched_features_dataset(file_pattern,
batch_size,
features,
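For reference, a minimal end-to-end sketch of the new `CsvDataset` (not part of the diff). The file name and column layout are made up, and the class is assumed to be exported as `tf.contrib.data.CsvDataset` elsewhere in this change:

```python
import tensorflow as tf

# Assumed file "example.csv" with rows like: name,reading,score,count
dataset = tf.contrib.data.CsvDataset(
    ["example.csv"],
    record_defaults=[
        tf.constant([], dtype=tf.float32),     # required float column
        tf.constant([0.0], dtype=tf.float32),  # optional float, defaults to 0.0
        tf.constant([], dtype=tf.int32),       # required int column
    ],
    select_cols=[1, 2, 3])                     # skip the leading string column

nxt = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  while True:
    try:
      print(sess.run(nxt))
    except tf.errors.OutOfRangeError:
      break
```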
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index a1d56066b4..6192f04c8b 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -94,7 +94,7 @@ cuda_py_test(
cuda_py_test(
name = "distribution_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/distribution_test.py"],
additional_deps = [
":distributions_py",
@@ -337,7 +337,7 @@ cuda_py_test(
cuda_py_test(
name = "mvn_tril_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/mvn_tril_test.py"],
additional_deps = [
":distributions_py",
@@ -710,6 +710,7 @@ cuda_py_test(
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:client_testlib",
],
+ shard_count = 4,
tags = ["noasan"], # times out, http://b/78588814
)
diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py
index 88ed012784..d813831bef 100644
--- a/tensorflow/contrib/distributions/python/ops/autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py
@@ -144,7 +144,7 @@ class Autoregressive(distribution_lib.Distribution):
`distribution_fn(sample0).event_shape.num_elements()` are both `None`.
ValueError: if `num_steps < 1`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
self._distribution_fn = distribution_fn
self._sample0 = sample0
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index bf5590cd55..8a4041cf43 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -104,7 +105,7 @@ class BatchReshape(distribution_lib.Distribution):
ValueError: if `batch_shape` size is not the same as a
`distribution.batch_shape` size.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
name = name or "BatchReshape" + distribution.name
self._distribution = distribution
with ops.name_scope(name, values=[batch_shape]) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py
index 12d1603178..24b26bf124 100644
--- a/tensorflow/contrib/distributions/python/ops/binomial.py
+++ b/tensorflow/contrib/distributions/python/ops/binomial.py
@@ -163,7 +163,7 @@ class Binomial(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[total_count, logits, probs]) as name:
self._total_count = self._maybe_assert_valid_total_count(
ops.convert_to_tensor(total_count, name="total_count"),
diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py
index daacfe657f..f5ffdd8731 100644
--- a/tensorflow/contrib/distributions/python/ops/cauchy.py
+++ b/tensorflow/contrib/distributions/python/ops/cauchy.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
"Cauchy",
@@ -120,7 +121,7 @@ class Cauchy(distribution.Distribution):
Raises:
TypeError: if `loc` and `scale` have different `dtype`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)]
if validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py
index c77c5fd208..08cdc15828 100644
--- a/tensorflow/contrib/distributions/python/ops/chi2.py
+++ b/tensorflow/contrib/distributions/python/ops/chi2.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import gamma
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -83,7 +84,7 @@ class Chi2(gamma.Gamma):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
# Even though all stats of chi2 are defined for valid parameters, this is
# not true in the parent class "gamma." therefore, passing
# allow_nan_stats=True
@@ -119,7 +120,7 @@ class Chi2WithAbsDf(Chi2):
validate_args=False,
allow_nan_stats=True,
name="Chi2WithAbsDf"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[df]) as name:
super(Chi2WithAbsDf, self).__init__(
df=math_ops.floor(
diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py
index a42350430e..6d7d6d307b 100644
--- a/tensorflow/contrib/distributions/python/ops/deterministic.py
+++ b/tensorflow/contrib/distributions/python/ops/deterministic.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
"Deterministic",
@@ -86,7 +87,7 @@ class _BaseDeterministic(distribution.Distribution):
Raises:
ValueError: If `loc` is a scalar.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, atol, rtol]) as name:
loc = ops.convert_to_tensor(loc, name="loc")
if is_vector and validate_args:
diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py
index 53dd42f4c8..446cff6ec2 100644
--- a/tensorflow/contrib/distributions/python/ops/geometric.py
+++ b/tensorflow/contrib/distributions/python/ops/geometric.py
@@ -85,7 +85,7 @@ class Geometric(distribution.Distribution):
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits, probs, validate_args=validate_args, name=name)
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index 2c261073ee..ed9ea6f4f3 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
class _Gumbel(distribution.Distribution):
@@ -124,7 +125,7 @@ class _Gumbel(distribution.Distribution):
Raises:
TypeError: if loc and scale are different dtypes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py
index d0df2befd6..7e12767f6d 100644
--- a/tensorflow/contrib/distributions/python/ops/half_normal.py
+++ b/tensorflow/contrib/distributions/python/ops/half_normal.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -105,7 +106,7 @@ class HalfNormal(distribution.Distribution):
if one or more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py
index fbde55ef31..fa89fff3b7 100644
--- a/tensorflow/contrib/distributions/python/ops/independent.py
+++ b/tensorflow/contrib/distributions/python/ops/independent.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import kullback_leibler
+from tensorflow.python.ops.distributions import util as distribution_util
class Independent(distribution_lib.Distribution):
@@ -116,7 +117,7 @@ class Independent(distribution_lib.Distribution):
ValueError: if `reinterpreted_batch_ndims` exceeds
`distribution.batch_ndims`
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
name = name or "Independent" + distribution.name
self._distribution = distribution
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index 502bd4f493..85e8e10466 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -125,7 +125,7 @@ class InverseGamma(distribution.Distribution):
Raises:
TypeError: if `concentration` and `rate` are different dtypes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration, rate]) as name:
with ops.control_dependencies([
check_ops.assert_positive(concentration),
@@ -280,7 +280,7 @@ class InverseGammaWithSoftplusConcentrationRate(InverseGamma):
validate_args=False,
allow_nan_stats=True,
name="InverseGammaWithSoftplusConcentrationRate"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration, rate]) as name:
super(InverseGammaWithSoftplusConcentrationRate, self).__init__(
concentration=nn.softplus(concentration,
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py
index c83b5bc2e3..0103283259 100644
--- a/tensorflow/contrib/distributions/python/ops/logistic.py
+++ b/tensorflow/contrib/distributions/python/ops/logistic.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
class Logistic(distribution.Distribution):
@@ -119,7 +120,7 @@ class Logistic(distribution.Distribution):
Raises:
TypeError: if loc and scale are different dtypes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index 2ef294af2e..d54f30dc63 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -116,7 +116,7 @@ class Mixture(distribution.Distribution):
matching static batch shapes, or all components do not
have matching static event shapes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
if not isinstance(cat, categorical.Categorical):
raise TypeError("cat must be a Categorical distribution, but saw: %s" %
cat)
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index 0b1301e551..c7c90cf875 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -130,7 +130,7 @@ class MixtureSameFamily(distribution.Distribution):
ValueError: if `mixture_distribution` categories does not equal
`components_distribution` rightmost batch shape.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
self._mixture_distribution = mixture_distribution
self._components_distribution = components_distribution
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
index e3236c2db9..cad398582b 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
@@ -193,7 +193,7 @@ class MultivariateNormalDiag(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[
loc, scale_diag, scale_identity_multiplier]):
@@ -224,7 +224,7 @@ class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag):
validate_args=False,
allow_nan_stats=True,
name="MultivariateNormalDiagWithSoftplusScale"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[scale_diag]) as name:
super(MultivariateNormalDiagWithSoftplusScale, self).__init__(
loc=loc,
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index 2f6a6f198c..1c11594df3 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -215,7 +215,7 @@ class MultivariateNormalDiagPlusLowRank(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
def _convert_to_tensor(x, name):
return None if x is None else ops.convert_to_tensor(x, name=name)
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index 5d06a396fe..47d7d13cf3 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -24,6 +24,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -155,7 +156,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
Raises:
ValueError: if neither `loc` nor `covariance_matrix` are specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
# Convert the covariance_matrix up to a scale_tril and call MVNTriL.
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
index 44c92312c7..79916fef8d 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
@@ -170,7 +170,7 @@ class MultivariateNormalLinearOperator(
ValueError: if `scale` is unspecified.
TypeError: if not `scale.dtype.is_floating`
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
if scale is None:
raise ValueError("Missing required `scale` parameter.")
if not scale.dtype.is_floating:
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index d6f8b731cb..d6b0ed994e 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -179,7 +179,7 @@ class MultivariateNormalTriL(
Raises:
ValueError: if neither `loc` nor `scale_tril` are specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
def _convert_to_tensor(x, name):
return None if x is None else ops.convert_to_tensor(x, name=name)
if loc is None and scale_tril is None:
diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py
index eeaf9c0a5e..1085c56dc8 100644
--- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py
+++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py
@@ -90,7 +90,7 @@ class NegativeBinomial(distribution.Distribution):
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[total_count, logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits, probs, validate_args=validate_args, name=name)
diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
index 305b138fdc..a4b9f3b78d 100644
--- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
@@ -115,7 +115,7 @@ class OneHotCategorical(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
name=name, logits=logits, probs=probs, validate_args=validate_args,
diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py
index a84aad6fc9..b345394021 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson.py
@@ -93,7 +93,7 @@ class Poisson(distribution.Distribution):
TypeError: if `rate` is not a float-type.
TypeError: if `log_rate` is not a float-type.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[rate]) as name:
if (rate is None) == (log_rate is None):
raise ValueError("Must specify exactly one of `rate` and `log_rate`.")
diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
index 19c99dcee9..fe72091d7d 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
@@ -255,7 +255,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
TypeError: if `quadrature_grid` and `quadrature_probs` have different base
`dtype`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
if loc is not None:
loc = ops.convert_to_tensor(loc, name="loc")
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index eb94760ad7..584d2c385f 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -263,7 +263,7 @@ class QuantizedDistribution(distributions.Distribution):
`Distribution` or continuous.
NotImplementedError: If the base distribution does not implement `cdf`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
values = (
list(distribution.parameters.values()) +
[low, high])
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
index 84c8d29072..0362996e68 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
@@ -165,7 +165,7 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution):
Raises:
ValueError: If both `probs` and `logits` are passed, or if neither.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[logits, probs, temperature]) as name:
with ops.control_dependencies([check_ops.assert_positive(temperature)]
if validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
index 325f41e37c..910c430ae7 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
@@ -162,7 +162,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[logits, probs, temperature]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
index 03828fa612..f04dc8da39 100644
--- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
@@ -132,7 +132,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution):
if one or more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name,
values=[loc, scale, skewness, tailweight]) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index af6ff8162b..cd6d749959 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -395,7 +395,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
ValueError: if `not distribution.is_scalar_batch`.
ValueError: if `not distribution.is_scalar_event`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[mix_loc, temperature]) as name:
if not scale or len(scale) < 2:
raise ValueError("Must specify list (or list-like object) of scale "
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
index e265b5d0f7..3465d66b30 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
@@ -175,7 +175,7 @@ class VectorExponentialDiag(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[
loc, scale_diag, scale_identity_multiplier]):
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
index 89136d6760..2c31b01984 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
@@ -175,7 +175,7 @@ class VectorExponentialLinearOperator(
ValueError: if `scale` is unspecified.
TypeError: if not `scale.dtype.is_floating`
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
if scale is None:
raise ValueError("Missing required `scale` parameter.")
if not scale.dtype.is_floating:
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
index 8dd983b750..6a36018d6f 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
@@ -210,7 +210,7 @@ class VectorLaplaceDiag(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name):
with ops.name_scope("init", values=[
loc, scale_diag, scale_identity_multiplier]):
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
index ec485c95c1..97e5c76d80 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
@@ -191,7 +191,7 @@ class VectorLaplaceLinearOperator(
ValueError: if `scale` is unspecified.
TypeError: if not `scale.dtype.is_floating`
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
if scale is None:
raise ValueError("Missing required `scale` parameter.")
if not scale.dtype.is_floating:
diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
index 1438ede265..ff5ca45257 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
@@ -163,7 +163,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(
name,
diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
index 7e78ded9df..4742f75218 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
@@ -175,7 +175,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
if one or more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
graph_parents = [df, loc, scale_identity_multiplier, scale_diag,
scale_tril, scale_perturb_factor, scale_perturb_diag]
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index 91453fed5d..f555867e7f 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -107,7 +107,7 @@ class _WishartLinearOperator(distribution.Distribution):
ValueError: if df < k, where scale operator event shape is
`(k, k)`
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
self._cholesky_input_output_matrices = cholesky_input_output_matrices
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[df, scale_operator]):
@@ -530,7 +530,7 @@ class WishartCholesky(_WishartLinearOperator):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[scale]) as name:
with ops.name_scope("init", values=[scale]):
scale = ops.convert_to_tensor(scale)
@@ -646,7 +646,7 @@ class WishartFull(_WishartLinearOperator):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[scale]):
scale = ops.convert_to_tensor(scale)
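The repeated change above replaces `parameters = locals()` with `distribution_util.parent_frame_arguments()`. A conceptual sketch of what that kind of helper does, purely illustrative and not the TensorFlow implementation, under the assumption that it collects the caller's call arguments and flattens anything passed through `**kwargs`:

```python
import inspect

def parent_frame_arguments_sketch():
  """Returns the calling function's arguments as a dict (excluding `self`)."""
  frame = inspect.currentframe().f_back
  arg_info = inspect.getargvalues(frame)
  args = {name: arg_info.locals[name]
          for name in arg_info.args if name != 'self'}
  if arg_info.varargs:                  # fold in *args, if declared
    args[arg_info.varargs] = arg_info.locals[arg_info.varargs]
  if arg_info.keywords:                 # flatten **kwargs into the dict
    args.update(arg_info.locals[arg_info.keywords])
  return args

class Demo(object):
  def __init__(self, loc, scale=1.0, **kwargs):
    self.parameters = parent_frame_arguments_sketch()

print(Demo(0.0, scale=2.0, validate_args=True).parameters)
# e.g. {'loc': 0.0, 'scale': 2.0, 'validate_args': True}
```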
diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD
new file mode 100644
index 0000000000..638c57d1c9
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/scan/BUILD
@@ -0,0 +1,25 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+cuda_py_test(
+ name = "scan_test",
+ size = "small",
+ srcs = ["scan_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cuda_py_test(
+ name = "scan_graph_test",
+ size = "small",
+ srcs = ["scan_graph_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
new file mode 100644
index 0000000000..4661dafbed
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit test for tf.scan under graph mode execution."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+import tensorflow as tf
+
+
+class ScanBenchmark(tf.test.Benchmark):
+
+ def runScan(self, n):
+ elems = np.arange(n)
+ start_time = time.time()
+ sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
+ with tf.Session() as sess:
+ sess.run(sum_op)
+ wall_time = time.time() - start_time
+
+ self.report_benchmark(
+ name='scan',
+ iters=n,
+ wall_time=wall_time)
+
+ def benchmarkScan32000(self):
+ self.runScan(32000)
+
+ def benchmarkScan1M(self):
+ self.runScan(1000000)
+
+ def benchmarkScan2M(self):
+ self.runScan(2000000)
+
+ def benchmarkScan4M(self):
+ self.runScan(4000000)
+
+ def benchmarkScan8M(self):
+ self.runScan(8000000)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py
new file mode 100644
index 0000000000..b8c7cf1fe5
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/scan/scan_test.py
@@ -0,0 +1,56 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit test for tf.scan under eager execution."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+import tensorflow as tf
+
+
+class ScanBenchmark(tf.test.Benchmark):
+
+ def runScan(self, n):
+ elems = np.arange(n)
+ start_time = time.time()
+ _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
+ wall_time = time.time() - start_time
+
+ self.report_benchmark(
+ name='scan',
+ iters=n,
+ wall_time=wall_time)
+
+ def benchmarkScan2000(self):
+ self.runScan(2000)
+
+ def benchmarkScan4000(self):
+ self.runScan(4000)
+
+ def benchmarkScan8000(self):
+ self.runScan(8000)
+
+ def benchmarkScan16000(self):
+ self.runScan(16000)
+
+ def benchmarkScan32000(self):
+ self.runScan(32000)
+
+if __name__ == '__main__':
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py
index 44828bea50..9af50ee146 100644
--- a/tensorflow/contrib/eager/python/network.py
+++ b/tensorflow/contrib/eager/python/network.py
@@ -23,7 +23,6 @@ import os
import weakref
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import ops
from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer
from tensorflow.python.layers import base
@@ -33,6 +32,7 @@ from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
+from tensorflow.python.util import function_utils
# pylint: disable=protected-access
# Explanation for protected-access disable: Network has lots of same-class and
@@ -545,10 +545,10 @@ class Sequential(Network):
def add(self, layer_func):
if isinstance(layer_func, base.Layer):
- args = estimator_util.fn_args(layer_func.call)
+ args = function_utils.fn_args(layer_func.call)
self.track_layer(layer_func)
elif callable(layer_func):
- args = estimator_util.fn_args(layer_func)
+ args = function_utils.fn_args(layer_func)
else:
raise TypeError(
"Sequential.add() takes only tf.layers.Layer objects or callables; "
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index e9a68801ef..df08dc2be6 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -14,6 +14,7 @@ py_library(
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
+ ":baseline",
":boosted_trees",
":dnn",
":dnn_linear_combined",
@@ -30,6 +31,49 @@ py_library(
)
py_library(
+ name = "baseline",
+ srcs = ["python/estimator/baseline.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:baseline",
+ ],
+)
+
+py_test(
+ name = "baseline_test",
+ size = "small",
+ srcs = ["python/estimator/baseline_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "notsan",
+ ],
+ deps = [
+ ":baseline",
+ ":head",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:session",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/estimator:export_export",
+ "//tensorflow/python/estimator:metric_keys",
+ "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "boosted_trees",
srcs = ["python/estimator/boosted_trees.py"],
srcs_version = "PY2AND3",
@@ -322,9 +366,9 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:util",
"//tensorflow/python/estimator:dnn",
"//tensorflow/python/estimator:linear",
- "//tensorflow/python/estimator:util",
],
)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index ec502f86dd..32a0f2545d 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.estimator.python.estimator.baseline import *
from tensorflow.contrib.estimator.python.estimator.boosted_trees import *
from tensorflow.contrib.estimator.python.estimator.dnn import *
from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import *
@@ -45,6 +46,7 @@ _allowed_symbols = [
'multi_label_head',
'poisson_regression_head',
'regression_head',
+ 'BaselineEstimator',
'DNNEstimator',
'DNNLinearCombinedEstimator',
'LinearEstimator',
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline.py b/tensorflow/contrib/estimator/python/estimator/baseline.py
new file mode 100644
index 0000000000..beffbee730
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/baseline.py
@@ -0,0 +1,98 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Baseline estimators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator.canned import baseline
+
+
+class BaselineEstimator(estimator.Estimator):
+ """An estimator that can establish a simple baseline.
+
+ The estimator uses a user-specified head.
+
+ This estimator ignores feature values and will learn to predict the average
+ value of each label. For example, for single-label classification problems, it
+ will predict the probability distribution of the classes as seen in the labels.
+ For multi-label classification problems, it will predict the ratio of examples
+ that contain each class.
+
+ Example:
+
+ ```python
+
+ # Build baseline multi-label classifier.
+ estimator = BaselineEstimator(
+ head=tf.contrib.estimator.multi_label_head(n_classes=3))
+
+ # Input builders
+ def input_fn_train(): # returns x, y (where y represents label's class index).
+ pass
+
+ def input_fn_eval(): # returns x, y (where y represents label's class index).
+ pass
+
+ # Fit model.
+ estimator.train(input_fn=input_fn_train)
+
+ # Evaluate cross entropy between the eval labels and the predictions.
+ loss = estimator.evaluate(input_fn=input_fn_eval)["loss"]
+
+ # For each class, predict the ratio of training examples that contain the
+ # class.
+ predictions = estimator.predict(input_fn=input_fn_eval)
+
+ ```
+
+ Input of `train` and `evaluate` should have the following features,
+ otherwise there will be a `KeyError`:
+
+ * if `weight_column` passed to the `head` constructor is not `None`, a feature
+ with `key=weight_column` whose value is a `Tensor`.
+ """
+
+ def __init__(self,
+ head,
+ model_dir=None,
+ optimizer='Ftrl',
+ config=None):
+ """Initializes a BaselineEstimator instance.
+
+ Args:
+ head: A `_Head` instance constructed with a method such as
+ `tf.contrib.estimator.multi_label_head`.
+ model_dir: Directory to save model parameters, graph, etc. This can
+ also be used to load checkpoints from the directory into an estimator to
+ continue training a previously saved model.
+ optimizer: String, `tf.Optimizer` object, or callable that creates the
+ optimizer to use for training. If not specified, will use
+ `FtrlOptimizer` with a default learning rate of 0.3.
+ config: `RunConfig` object to configure the runtime settings.
+ """
+ def _model_fn(features, labels, mode, config):
+ return baseline._baseline_model_fn( # pylint: disable=protected-access
+ features=features,
+ labels=labels,
+ mode=mode,
+ head=head,
+ optimizer=optimizer,
+ config=config)
+ super(BaselineEstimator, self).__init__(
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config)
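A minimal training sketch for the new estimator (not part of the diff), assuming the `tf.contrib.estimator.BaselineEstimator` and `tf.contrib.estimator.regression_head` exports and using made-up data:

```python
import numpy as np
import tensorflow as tf

estimator = tf.contrib.estimator.BaselineEstimator(
    head=tf.contrib.estimator.regression_head(label_dimension=1))

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': np.array([[1.], [2.], [3.], [4.]])},
    y=np.array([[10.], [10.], [12.], [12.]]),
    batch_size=2, num_epochs=None, shuffle=True)

# Because the features are ignored, the learned bias converges toward the
# mean label (here ~11.0).
estimator.train(train_input_fn, steps=200)
```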
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
new file mode 100644
index 0000000000..d0e3e670f7
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
@@ -0,0 +1,430 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for baseline.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import tempfile
+
+import numpy as np
+import six
+
+from tensorflow.contrib.estimator.python.estimator import baseline
+from tensorflow.contrib.estimator.python.estimator import head as head_lib
+from tensorflow.python.client import session as tf_session
+from tensorflow.python.estimator.canned import metric_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import saver
+
+# Names of variables created by model.
+BIAS_NAME = 'baseline/bias'
+
+
+def assert_close(expected, actual, rtol=1e-04, name='assert_close'):
+ with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:
+ expected = ops.convert_to_tensor(expected, name='expected')
+ actual = ops.convert_to_tensor(actual, name='actual')
+ rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected)
+ rtol = ops.convert_to_tensor(rtol, name='rtol')
+ return check_ops.assert_less(
+ rdiff,
+ rtol,
+ data=('Condition expected =~ actual did not hold element-wise: '
+ 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,
+ 'rtol = ', rtol,),
+ name=scope)
+
+
+def save_variables_to_ckpt(model_dir):
+ init_all_op = [variables.global_variables_initializer()]
+ with tf_session.Session() as sess:
+ sess.run(init_all_op)
+ saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
+
+
+def _baseline_estimator_fn(
+ weight_column=None, label_dimension=1, *args, **kwargs):
+ """Returns a BaselineEstimator that uses regression_head."""
+ return baseline.BaselineEstimator(
+ head=head_lib.regression_head(
+ weight_column=weight_column, label_dimension=label_dimension,
+ # Tests in core (from which this test inherits) test the sum loss.
+ loss_reduction=losses.Reduction.SUM),
+ *args, **kwargs)
+
+
+class BaselineEstimatorEvaluationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def test_evaluation_batch(self):
+ """Tests evaluation for batch_size==2."""
+ with ops.Graph().as_default():
+ variables.Variable([13.0], name=BIAS_NAME)
+ variables.Variable(
+ 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir)
+ eval_metrics = baseline_estimator.evaluate(
+ input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1)
+
+ # Logit is bias = 13, while label is 10.
+ # Loss per example is 3**2 = 9.
+ # Training loss is the sum over the batch = 9 + 9 = 18.
+ # Average loss is the average over the batch = 9.
+ self.assertDictEqual({
+ metric_keys.MetricKeys.LOSS: 18.,
+ metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ ops.GraphKeys.GLOBAL_STEP: 100
+ }, eval_metrics)
+
+ def test_evaluation_weights(self):
+ """Tests evaluation with weights."""
+ with ops.Graph().as_default():
+ variables.Variable([13.0], name=BIAS_NAME)
+ variables.Variable(
+ 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ def _input_fn():
+ features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))}
+ labels = ((10.,), (10.,))
+ return features, labels
+
+ baseline_estimator = _baseline_estimator_fn(
+ weight_column='weights',
+ model_dir=self._model_dir)
+ eval_metrics = baseline_estimator.evaluate(input_fn=_input_fn, steps=1)
+
+ # Logit is bias = 13, while label is 10.
+ # Loss per example is 3**2 = 9.
+ # Training loss is the weighted sum over batch = 9 + 2*9 = 27
+ # Average loss is the weighted average = (9 + 2*9) / (1 + 2) = 9
+ self.assertDictEqual({
+ metric_keys.MetricKeys.LOSS: 27.,
+ metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ ops.GraphKeys.GLOBAL_STEP: 100
+ }, eval_metrics)
+
+ def test_evaluation_for_multi_dimensions(self):
+ label_dim = 2
+ with ops.Graph().as_default():
+ variables.Variable([46.0, 58.0], name=BIAS_NAME)
+ variables.Variable(100, name='global_step', dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ baseline_estimator = _baseline_estimator_fn(
+ label_dimension=label_dim,
+ model_dir=self._model_dir)
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'age': np.array([[2., 4., 5.]]),
+ },
+ y=np.array([[46., 58.]]),
+ batch_size=1,
+ num_epochs=None,
+ shuffle=False)
+ eval_metrics = baseline_estimator.evaluate(input_fn=input_fn, steps=1)
+
+ self.assertItemsEqual(
+ (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
+ ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+
+ # Logit is bias which is [46, 58]
+ self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])
+
+
+class BaselineEstimatorPredictTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def test_1d(self):
+ """Tests predict when all variables are one-dimensional."""
+ with ops.Graph().as_default():
+ variables.Variable([.2], name=BIAS_NAME)
+ variables.Variable(100, name='global_step', dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir)
+
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': np.array([[2.]])},
+ y=None,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False)
+ predictions = baseline_estimator.predict(input_fn=predict_input_fn)
+ predicted_scores = list([x['predictions'] for x in predictions])
+ # score = bias = .2; the baseline model ignores the input features.
+ self.assertAllClose([[.2]], predicted_scores)
+
+ def testMultiDim(self):
+ """Tests predict when all variables are multi-dimenstional."""
+ batch_size = 2
+ label_dimension = 3
+ with ops.Graph().as_default():
+ variables.Variable( # shape=[label_dimension]
+ [.2, .4, .6], name=BIAS_NAME)
+ variables.Variable(100, name='global_step', dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ baseline_estimator = _baseline_estimator_fn(
+ label_dimension=label_dimension,
+ model_dir=self._model_dir)
+
+ predict_input_fn = numpy_io.numpy_input_fn(
+ # x shape=[batch_size, x_dim]
+ x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])},
+ y=None,
+ batch_size=batch_size,
+ num_epochs=1,
+ shuffle=False)
+ predictions = baseline_estimator.predict(input_fn=predict_input_fn)
+ predicted_scores = list([x['predictions'] for x in predictions])
+ # score = bias, shape=[batch_size, label_dimension]
+ self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]],
+ predicted_scores)
+
+
+class BaselineEstimatorIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, prediction_length):
+ feature_columns = [
+ feature_column_lib.numeric_column('x', shape=(input_dimension,))
+ ]
+ est = _baseline_estimator_fn(
+ label_dimension=label_dimension,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ # learn y = x
+ est.train(train_input_fn, steps=200)
+
+ # EVALUATE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))
+
+ # PREDICT
+ predictions = np.array(
+ [x['predictions'] for x in est.predict(predict_input_fn)])
+ self.assertAllEqual((prediction_length, label_dimension), predictions.shape)
+
+ # EXPORT
+ feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ label_dimension = 2
+ input_dimension = label_dimension
+ batch_size = 10
+ prediction_length = batch_size
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ num_epochs=1,
+ shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=None,
+ batch_size=batch_size,
+ num_epochs=1,
+ shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=input_dimension,
+ label_dimension=label_dimension,
+ prediction_length=prediction_length)
+
+
+class BaselineEstimatorTrainingTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _mock_optimizer(self, expected_loss=None):
+ expected_var_names = [
+ '%s:0' % BIAS_NAME
+ ]
+
+ def _minimize(loss, global_step=None, var_list=None):
+ trainable_vars = var_list or ops.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertItemsEqual(expected_var_names,
+ [var.name for var in trainable_vars])
+
+ # Verify loss. We can't check the value directly, so we add an assert op.
+ self.assertEquals(0, loss.shape.ndims)
+ if expected_loss is None:
+ if global_step is not None:
+ return distribute_lib.increment_var(global_step)
+ return control_flow_ops.no_op()
+ assert_loss = assert_close(
+ math_ops.to_float(expected_loss, name='expected'),
+ loss,
+ name='assert_loss')
+ with ops.control_dependencies((assert_loss,)):
+ if global_step is not None:
+ return distribute_lib.increment_var(global_step)
+ return control_flow_ops.no_op()
+
+ mock_optimizer = test.mock.NonCallableMock(
+ spec=optimizer.Optimizer,
+ wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer'))
+ mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize)
+
+ # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.
+ # So, return mock_optimizer itself for deepcopy.
+ mock_optimizer.__deepcopy__ = lambda _: mock_optimizer
+ return mock_optimizer
+
+ def _assert_checkpoint(self,
+ label_dimension,
+ expected_global_step,
+ expected_bias=None):
+ shapes = {
+ name: shape
+ for (name, shape) in checkpoint_utils.list_variables(self._model_dir)
+ }
+
+ self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
+ self.assertEqual(expected_global_step,
+ checkpoint_utils.load_variable(self._model_dir,
+ ops.GraphKeys.GLOBAL_STEP))
+
+ self.assertEqual([label_dimension], shapes[BIAS_NAME])
+ if expected_bias is not None:
+ self.assertEqual(expected_bias,
+ checkpoint_utils.load_variable(self._model_dir,
+ BIAS_NAME))
+
+ def testFromScratch(self):
+ # Create BaselineRegressor.
+ label = 5.
+ age = 17
+ # loss = (logits - label)^2 = (0 - 5.)^2 = 25.
+ mock_optimizer = self._mock_optimizer(expected_loss=25.)
+ baseline_estimator = _baseline_estimator_fn(
+ model_dir=self._model_dir,
+ optimizer=mock_optimizer)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ baseline_estimator.train(
+ input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+ self._assert_checkpoint(
+ label_dimension=1,
+ expected_global_step=num_steps,
+ expected_bias=[0.])
+
+ def testFromCheckpoint(self):
+ # Create initial checkpoint.
+ bias = 7.0
+ initial_global_step = 100
+ with ops.Graph().as_default():
+ variables.Variable([bias], name=BIAS_NAME)
+ variables.Variable(
+ initial_global_step,
+ name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ # logits = bias = 7.
+ # loss = (logits - label)^2 = (7 - 5)^2 = 4
+ mock_optimizer = self._mock_optimizer(expected_loss=4.)
+ baseline_estimator = _baseline_estimator_fn(
+ model_dir=self._model_dir,
+ optimizer=mock_optimizer)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+ # Train for a few steps, and validate optimizer and final checkpoint.
+ num_steps = 10
+ baseline_estimator.train(
+ input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+ self._assert_checkpoint(
+ label_dimension=1,
+ expected_global_step=initial_global_step + num_steps,
+ expected_bias=[bias])
+
+
+if __name__ == '__main__':
+ test.main()
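
As a quick cross-check of the loss values asserted in the evaluation tests above, the same arithmetic can be reproduced in a few lines of NumPy. This is an illustrative sketch only, not part of the diff; the bias, labels, and weights are copied from the tests.

```python
import numpy as np

bias = 13.0                          # the single BaselineEstimator variable
labels = np.array([10.0, 10.0])      # batch of two examples
weights = np.array([1.0, 2.0])       # per-example weights from the weighted test

per_example_loss = (bias - labels) ** 2              # [9., 9.]
print(np.sum(per_example_loss))                      # 18.0 -> LOSS (sum reduction)
print(np.mean(per_example_loss))                     # 9.0  -> LOSS_MEAN

weighted_sum = np.sum(weights * per_example_loss)    # 27.0 -> LOSS in the weighted test
print(weighted_sum / np.sum(weights))                # 9.0  -> LOSS_MEAN in the weighted test
```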
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py
index cf6e3329d2..7ff25b95c0 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn.py
@@ -93,7 +93,7 @@ class DNNEstimator(estimator.Estimator):
dropout=None,
input_layer_partitioner=None,
config=None):
- """Initializes a `DNNClassifier` instance.
+ """Initializes a `DNNEstimator` instance.
Args:
head: A `_Head` instance constructed with a method such as
diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py
index e7e366a3f2..03cf6f107c 100644
--- a/tensorflow/contrib/estimator/python/estimator/export.py
+++ b/tensorflow/contrib/estimator/python/estimator/export.py
@@ -60,38 +60,16 @@ def export_saved_model_for_mode(
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
+ weights = graph.get_tensor_by_name('linear/linear_model/age/weights')
...
```
- This method takes an input_receiver_fn and mode. For the mode passed in,
- this method builds a new graph by calling the input_receiver_fn to obtain
- feature and label `Tensor`s. Next, this method calls the `Estimator`'s
- model_fn in the passed mode to generate the model graph based on
- those features and labels, and restores the given checkpoint
- (or, lacking that, the most recent checkpoint) into the graph.
- Finally, it creates a timestamped export directory below the
- export_dir_base, and writes a `SavedModel` into it containing
- the `MetaGraphDef` for the given mode and its associated signatures.
-
- For prediction, the exported `MetaGraphDef` will provide one `SignatureDef`
- for each element of the export_outputs dict returned from the model_fn,
- named using the same keys. One of these keys is always
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
- signature will be served when a serving request does not specify one.
- For each signature, the outputs are provided by the corresponding
- `ExportOutput`s, and the inputs are always the input receivers provided by
- the serving_input_receiver_fn.
+ This method is a wrapper for _export_all_saved_models, and wraps a raw
+ input_receiver_fn in a dictionary to pass in to that function.
+ See _export_all_saved_models for full docs.
- For training and evaluation, the train_op is stored in an extra collection,
- and loss, metrics, and predictions are included in a SignatureDef for the
- mode in question.
-
- Extra assets may be written into the SavedModel via the assets_extra
- argument. This should be a dict, where each key gives a destination path
- (including the filename) relative to the assets.extra directory. The
- corresponding value gives the full path of the source file to be copied.
- For example, the simple case of copying a single file without renaming it
- is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
+ See tf.contrib.estimator.export_saved_model_for_mode for the currently
+ exposed version of this function.
Args:
estimator: an instance of tf.estimator.Estimator
@@ -138,10 +116,39 @@ def export_all_saved_models(
# pylint: disable=line-too-long
"""Exports requested train/eval/predict graphs as separate SavedModels.
- This is a wrapper around export_saved_model_for_mode that accepts
- multiple modes simultaneously and creates directories for each under
- export_dir_base. See `Estimator.export_saved_model_for_mode` for
- further details as to how the export works for each mode.
+ See tf.contrib.estimator.export_all_saved_models for the currently
+ exposed version of this function.
+
+ For each mode passed in via the input_receiver_fn_map,
+ this method builds a new graph by calling the input_receiver_fn to obtain
+ feature and label `Tensor`s. Next, this method calls the `Estimator`'s
+ model_fn in the passed mode to generate the model graph based on
+ those features and labels, and restores the given checkpoint
+ (or, lacking that, the most recent checkpoint) into the graph.
+ Only one of the modes is used for saving variables to the SavedModel
+ (order of preference: TRAIN, EVAL, then PREDICT), such that up to three
+ MetaGraphDefs are saved with a single set of variables in a single
+ SavedModel directory.
+
+ For prediction, the exported `MetaGraphDef` will provide one `SignatureDef`
+ for each element of the export_outputs dict returned from the model_fn,
+ named using the same keys. One of these keys is always
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
+ signature will be served when a serving request does not specify one.
+ For each signature, the outputs are provided by the corresponding
+ `ExportOutput`s, and the inputs are always the input receivers provided by
+ the serving_input_receiver_fn.
+
+ For training and evaluation, the train_op is stored in an extra collection,
+ and loss, metrics, and predictions are included in a SignatureDef for the
+ mode in question.
+
+ Extra assets may be written into the SavedModel via the assets_extra
+ argument. This should be a dict, where each key gives a destination path
+ (including the filename) relative to the assets.extra directory. The
+ corresponding value gives the full path of the source file to be copied.
+ For example, the simple case of copying a single file without renaming it
+ is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
Sample usage:
```python
@@ -166,7 +173,7 @@ def export_all_saved_models(
model_fn_lib.ModeKeys.PREDICT: serve_rcvr_fn,
}
- export_dirs = tf.contrib.estimator.export_all_saved_models(
+ export_dir = tf.contrib.estimator.export_all_saved_models(
classifier,
export_dir_base='my_model/',
input_receiver_fn_map=rcvr_fn_map)
@@ -175,8 +182,8 @@ def export_all_saved_models(
# can be used for serving, analysis with TFMA, or directly loaded in.
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
- loader.load(sess, [tag_constants.TRAINING],
- export_dirs[tf.estimator.ModeKeys.TRAIN])
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ weights = graph.get_tensor_by_name('linear/linear_model/age/weights')
...
```
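
For context, the wrapping relationship described in the updated docstrings can be sketched as follows. This is an illustrative sketch, not code from the patch; `export_single_mode`, `my_estimator`, and `my_receiver_fn` are placeholder names, and the call mirrors the `tf.contrib.estimator.export_all_saved_models` usage shown in the docstring example above.

```python
import tensorflow as tf

def export_single_mode(my_estimator, export_dir_base, my_receiver_fn,
                       mode=tf.estimator.ModeKeys.PREDICT):
  # Wrap the raw receiver fn in a one-entry map and delegate to the
  # all-modes export, which now returns a single export directory.
  input_receiver_fn_map = {mode: my_receiver_fn}
  return tf.contrib.estimator.export_all_saved_models(
      my_estimator, export_dir_base,
      input_receiver_fn_map=input_receiver_fn_map)
```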
diff --git a/tensorflow/contrib/estimator/python/estimator/export_test.py b/tensorflow/contrib/estimator/python/estimator/export_test.py
index 89d02582e1..050821ee67 100644
--- a/tensorflow/contrib/estimator/python/estimator/export_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/export_test.py
@@ -166,12 +166,9 @@ class EstimatorExportTest(test.TestCase):
input_receiver_fn_map = {
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 1)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
@@ -188,12 +185,9 @@ class EstimatorExportTest(test.TestCase):
input_receiver_fn_map = {
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 1)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -211,12 +205,9 @@ class EstimatorExportTest(test.TestCase):
input_receiver_fn_map = {
model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 1)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.EVAL], export_dir)
@@ -235,12 +226,9 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 2)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -249,7 +237,7 @@ class EstimatorExportTest(test.TestCase):
self.assertFalse('eval_multiplied' in graph_ops)
self.assertTrue('feature_x' in graph_ops)
self.assertTrue('weight' in graph_ops)
- export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL]
+
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.EVAL], export_dir)
@@ -270,12 +258,11 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
# Restore, to validate that the export was well-formed.
- for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items():
- export_dir = export_dirs[mode]
+ for tag_set in model_fn_lib.EXPORT_TAG_MAP.values():
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, tag_set, export_dir)
@@ -292,10 +279,9 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -303,7 +289,6 @@ class EstimatorExportTest(test.TestCase):
self.assertTrue('later_var' in graph_ops)
self.assertTrue('weight' in graph_ops)
- export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
@@ -319,10 +304,9 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -332,7 +316,6 @@ class EstimatorExportTest(test.TestCase):
collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertEqual(3, collection_vars[-1].eval())
- export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
@@ -360,16 +343,15 @@ class EstimatorExportTest(test.TestCase):
# Perform the export.
export_dir_base = os.path.join(
compat.as_bytes(tmpdir), compat.as_bytes('export'))
- export_dirs = contrib_export.export_all_saved_models(
+ export_dir = contrib_export.export_all_saved_models(
est, export_dir_base, input_receiver_fn_map)
# Check that all the files are in the right places.
self.assertTrue(gfile.Exists(export_dir_base))
- for _, export_dir in export_dirs.items():
- self._validate_exported_files(export_dir)
+ self._validate_exported_files(export_dir)
- return export_dirs, tmpdir
+ return export_dir, tmpdir
def _validate_exported_files(self, export_dir):
self.assertTrue(gfile.Exists(export_dir))
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py
index 201699ed77..bf08be09e7 100644
--- a/tensorflow/contrib/estimator/python/estimator/extenders.py
+++ b/tensorflow/contrib/estimator/python/estimator/extenders.py
@@ -22,12 +22,12 @@ import six
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import clip_ops
from tensorflow.python.training import optimizer as optimizer_lib
+from tensorflow.python.util import function_utils
_VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])
@@ -330,7 +330,7 @@ class _TransformGradients(optimizer_lib.Optimizer):
def _verify_metric_fn_args(metric_fn):
- args = set(estimator_util.fn_args(metric_fn))
+ args = set(function_utils.fn_args(metric_fn))
invalid_args = list(args - _VALID_METRIC_FN_ARGS)
if invalid_args:
raise ValueError('metric_fn (%s) has following not expected args: %s' %
@@ -339,7 +339,7 @@ def _verify_metric_fn_args(metric_fn):
def _call_metric_fn(metric_fn, features, labels, predictions, config):
"""Calls metric fn with proper arguments."""
- metric_fn_args = estimator_util.fn_args(metric_fn)
+ metric_fn_args = function_utils.fn_args(metric_fn)
kwargs = {}
if 'features' in metric_fn_args:
kwargs['features'] = features
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 109fdd3883..8b97f86db1 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import six
+
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
@@ -72,6 +74,33 @@ def multi_class_head(n_classes,
shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
the input labels before passing them to `loss_fn`.
+ The head can be used with a canned estimator. Example:
+
+ ```python
+ my_head = tf.contrib.estimator.multi_class_head(n_classes=3)
+ my_estimator = tf.contrib.estimator.DNNEstimator(
+ head=my_head,
+ hidden_units=...,
+ feature_columns=...)
+ ```
+
+ It can also be used with a custom `model_fn`. Example:
+
+ ```python
+ def _my_model_fn(features, labels, mode):
+ my_head = tf.contrib.estimator.multi_class_head(n_classes=3)
+ logits = tf.keras.Model(...)(features)
+
+ return my_head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ optimizer=tf.AdagradOptimizer(learning_rate=0.1),
+ logits=logits)
+
+ my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
+ ```
+
Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
`binary_classification_head`).
@@ -139,6 +168,33 @@ def binary_classification_head(
shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
the input labels before passing them to `loss_fn`.
+ The head can be used with a canned estimator. Example:
+
+ ```python
+ my_head = tf.contrib.estimator.binary_classification_head()
+ my_estimator = tf.contrib.estimator.DNNEstimator(
+ head=my_head,
+ hidden_units=...,
+ feature_columns=...)
+ ```
+
+ It can also be used with a custom `model_fn`. Example:
+
+ ```python
+ def _my_model_fn(features, labels, mode):
+ my_head = tf.contrib.estimator.binary_classification_head()
+ logits = tf.keras.Model(...)(features)
+
+ return my_head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ optimizer=tf.AdagradOptimizer(learning_rate=0.1),
+ logits=logits)
+
+ my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
+ ```
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -211,6 +267,33 @@ def regression_head(weight_column=None,
https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function
Namely, for poisson regression, set `inverse_link_fn=tf.exp`.
+ The head can be used with a canned estimator. Example:
+
+ ```python
+ my_head = tf.contrib.estimator.regression_head()
+ my_estimator = tf.contrib.estimator.DNNEstimator(
+ head=my_head,
+ hidden_units=...,
+ feature_columns=...)
+ ```
+
+ It can also be used with a custom `model_fn`. Example:
+
+ ```python
+ def _my_model_fn(features, labels, mode):
+ my_head = tf.contrib.estimator.regression_head()
+ logits = tf.keras.Model(...)(features)
+
+ return my_head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ optimizer=tf.AdagradOptimizer(learning_rate=0.1),
+ logits=logits)
+
+ my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
+ ```
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -270,6 +353,33 @@ def poisson_regression_head(
This is implemented as a generalized linear model, see
https://en.wikipedia.org/wiki/Generalized_linear_model.
+ The head can be used with a canned estimator. Example:
+
+ ```python
+ my_head = tf.contrib.estimator.poisson_regression_head()
+ my_estimator = tf.contrib.estimator.DNNEstimator(
+ head=my_head,
+ hidden_units=...,
+ feature_columns=...)
+ ```
+
+ It can also be used with a custom `model_fn`. Example:
+
+ ```python
+ def _my_model_fn(features, labels, mode):
+ my_head = tf.contrib.estimator.poisson_regression_head()
+ logits = tf.keras.Model(...)(features)
+
+ return my_head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ optimizer=tf.AdagradOptimizer(learning_rate=0.1),
+ logits=logits)
+
+ my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
+ ```
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -337,6 +447,33 @@ def logistic_regression_head(
This is implemented as a generalized linear model, see
https://en.wikipedia.org/wiki/Generalized_linear_model.
+ The head can be used with a canned estimator. Example:
+
+ ```python
+ my_head = tf.contrib.estimator.logistic_regression_head()
+ my_estimator = tf.contrib.estimator.DNNEstimator(
+ head=my_head,
+ hidden_units=...,
+ feature_columns=...)
+ ```
+
+ It can also be used with a custom `model_fn`. Example:
+
+ ```python
+ def _my_model_fn(features, labels, mode):
+ my_head = tf.contrib.estimator.logistic_regression_head()
+ logits = tf.keras.Model(...)(features)
+
+ return my_head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ optimizer=tf.AdagradOptimizer(learning_rate=0.1),
+ logits=logits)
+
+ my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
+ ```
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -375,6 +512,7 @@ def multi_label_head(n_classes,
label_vocabulary=None,
loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
loss_fn=None,
+ classes_for_class_based_metrics=None,
name=None):
"""Creates a `_Head` for multi-label classification.
@@ -406,6 +544,33 @@ def multi_label_head(n_classes,
shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
`label_vocabulary` to the input labels before passing them to `loss_fn`.
+ The head can be used with a canned estimator. Example:
+
+ ```python
+ my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
+ my_estimator = tf.contrib.estimator.DNNEstimator(
+ head=my_head,
+ hidden_units=...,
+ feature_columns=...)
+ ```
+
+ It can also be used with a custom `model_fn`. Example:
+
+ ```python
+ def _my_model_fn(features, labels, mode):
+ my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
+ logits = tf.keras.Model(...)(features)
+
+ return my_head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ optimizer=tf.AdagradOptimizer(learning_rate=0.1),
+ logits=logits)
+
+ my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
+ ```
+
Args:
n_classes: Number of classes, must be greater than 1 (for 1 class, use
`binary_classification_head`).
@@ -427,6 +592,10 @@ def multi_label_head(n_classes,
reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely
weighted sum of losses divided by batch size. See `tf.losses.Reduction`.
loss_fn: Optional loss function.
+ classes_for_class_based_metrics: List of integer class IDs or string class
+ names for which per-class metrics are evaluated. If integers, all must be
+ in the range `[0, n_classes - 1]`. If strings, all must be in
+ `label_vocabulary`.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -434,8 +603,8 @@ def multi_label_head(n_classes,
An instance of `_Head` for multi-label classification.
Raises:
- ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is
- invalid.
+ ValueError: if `n_classes`, `thresholds`, `loss_reduction`, `loss_fn` or
+ `classes_for_class_based_metrics` is invalid.
"""
thresholds = tuple(thresholds) if thresholds else tuple()
if n_classes is None or n_classes < 2:
@@ -460,10 +629,31 @@ def multi_label_head(n_classes,
if (loss_reduction not in losses.Reduction.all() or
loss_reduction == losses.Reduction.NONE):
raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
+ classes_for_class_based_metrics = tuple(
+ [] if classes_for_class_based_metrics is None
+ else classes_for_class_based_metrics)
+ if classes_for_class_based_metrics:
+ if isinstance(classes_for_class_based_metrics[0], six.string_types):
+ if not label_vocabulary:
+ raise ValueError(
+ 'label_vocabulary must be provided when '
+ 'classes_for_class_based_metrics are strings.')
+ class_ids = []
+ for class_string in classes_for_class_based_metrics:
+ class_ids.append(label_vocabulary.index(class_string))
+ classes_for_class_based_metrics = tuple(class_ids)
+ else:
+ for class_id in classes_for_class_based_metrics:
+ if (class_id < 0) or (class_id >= n_classes):
+ raise ValueError(
+ 'All classes_for_class_based_metrics must be in range [0, {}]. '
+ 'Given: {}'.format(n_classes - 1, class_id))
return _MultiLabelHead(
n_classes=n_classes, weight_column=weight_column, thresholds=thresholds,
label_vocabulary=label_vocabulary, loss_reduction=loss_reduction,
- loss_fn=loss_fn, name=name)
+ loss_fn=loss_fn,
+ classes_for_class_based_metrics=classes_for_class_based_metrics,
+ name=name)
class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
@@ -476,6 +666,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
label_vocabulary=None,
loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
loss_fn=None,
+ classes_for_class_based_metrics=None,
name=None):
self._n_classes = n_classes
self._weight_column = weight_column
@@ -483,6 +674,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
self._label_vocabulary = label_vocabulary
self._loss_reduction = loss_reduction
self._loss_fn = loss_fn
+ self._classes_for_class_based_metrics = classes_for_class_based_metrics
self._name = name
@property
@@ -737,4 +929,36 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
weights=weights,
threshold=threshold,
name=recall_key))
+ for class_id in self._classes_for_class_based_metrics:
+ batch_rank = array_ops.rank(probabilities) - 1
+ begin = array_ops.concat(
+ [array_ops.zeros([batch_rank], dtype=dtypes.int32), [class_id]],
+ axis=0)
+ size = array_ops.concat(
+ [-1 * array_ops.ones([batch_rank], dtype=dtypes.int32), [1]],
+ axis=0)
+ class_probabilities = array_ops.slice(
+ probabilities, begin=begin, size=size)
+ class_labels = array_ops.slice(labels, begin=begin, size=size)
+ prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id
+ metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access
+ head_lib._predictions_mean( # pylint:disable=protected-access
+ predictions=class_probabilities,
+ weights=weights,
+ name=prob_key))
+ auc_key = keys.AUC_AT_CLASS % class_id
+ metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access
+ head_lib._auc( # pylint:disable=protected-access
+ labels=class_labels,
+ predictions=class_probabilities,
+ weights=weights,
+ name=auc_key))
+ auc_pr_key = keys.AUC_PR_AT_CLASS % class_id
+ metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access
+ head_lib._auc( # pylint:disable=protected-access
+ labels=class_labels,
+ predictions=class_probabilities,
+ weights=weights,
+ curve='PR',
+ name=auc_pr_key))
return metric_ops
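
The begin/size tensors built above slice a single class column out of `probabilities` for any batch rank. For a rank-2 batch the slice is equivalent to this small NumPy sketch (illustration only; the probability values are made up):

```python
import numpy as np

probabilities = np.array([[0.2, 0.7, 0.1],
                          [0.5, 0.4, 0.1]])   # shape [batch_size, n_classes]
class_id = 1

# Equivalent of array_ops.slice(probabilities, begin=[0, class_id], size=[-1, 1])
class_probabilities = probabilities[:, class_id:class_id + 1]
print(class_probabilities)                    # [[0.7], [0.4]], shape [batch_size, 1]
```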
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index 19b86df556..d6c158608b 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -175,6 +175,21 @@ class MultiLabelHead(test.TestCase):
r'loss_fn has unexpected args: \[\'name\'\]'):
head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn)
+ def test_classes_for_class_based_metrics_invalid(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'All classes_for_class_based_metrics must be in range \[0, 2\]\. '
+ r'Given: -1'):
+ head_lib.multi_label_head(
+ n_classes=3, classes_for_class_based_metrics=[2, -1])
+
+ def test_classes_for_class_based_metrics_string_invalid(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'\'z\' is not in list'):
+ head_lib.multi_label_head(
+ n_classes=3, label_vocabulary=['a', 'b', 'c'],
+ classes_for_class_based_metrics=['c', 'z'])
+
def test_name(self):
head = head_lib.multi_label_head(n_classes=4, name='foo')
self.assertEqual('foo', head.name)
@@ -591,6 +606,81 @@ class MultiLabelHead(test.TestCase):
expected_loss=expected_loss,
expected_metrics=expected_metrics)
+ def test_eval_with_classes_for_class_based_metrics(self):
+ head = head_lib.multi_label_head(
+ n_classes=2, classes_for_class_based_metrics=[0, 1])
+
+ logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
+ labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # loss = labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels, logits=logits))
+
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ # Average loss over examples.
+ keys.LOSS_MEAN: expected_loss,
+ # auc and auc_pr cannot be reliably calculated for only 4 samples, but
+ # this assert tests that the algorithm remains consistent.
+ keys.AUC: 0.3333,
+ keys.AUC_PR: 0.7639,
+ keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2.,
+ keys.AUC_AT_CLASS % 0: 0.,
+ keys.AUC_PR_AT_CLASS % 0: 1.,
+ keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2.,
+ keys.AUC_AT_CLASS % 1: 1.,
+ keys.AUC_PR_AT_CLASS % 1: 1.,
+ }
+
+ self._test_eval(
+ head=head,
+ logits=logits,
+ labels=labels,
+ expected_loss=expected_loss,
+ expected_metrics=expected_metrics)
+
+ def test_eval_with_classes_for_class_based_metrics_string(self):
+ head = head_lib.multi_label_head(
+ n_classes=2, label_vocabulary=['a', 'b'],
+ classes_for_class_based_metrics=['a', 'b'])
+
+ logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
+ labels = sparse_tensor.SparseTensor(
+ values=['a', 'a', 'b'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ labels_onehot = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # loss = labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels_onehot, logits=logits))
+
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ # Average loss over examples.
+ keys.LOSS_MEAN: expected_loss,
+ # auc and auc_pr cannot be reliably calculated for only 4 samples, but
+ # this assert tests that the algorithm remains consistent.
+ keys.AUC: 0.3333,
+ keys.AUC_PR: 0.7639,
+ keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2.,
+ keys.AUC_AT_CLASS % 0: 0.,
+ keys.AUC_PR_AT_CLASS % 0: 1.,
+ keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2.,
+ keys.AUC_AT_CLASS % 1: 1.,
+ keys.AUC_PR_AT_CLASS % 1: 1.,
+ }
+
+ self._test_eval(
+ head=head,
+ logits=logits,
+ labels=labels,
+ expected_loss=expected_loss,
+ expected_metrics=expected_metrics)
+
def test_eval_with_weights(self):
n_classes = 2
head = head_lib.multi_label_head(n_classes, weight_column='example_weights')
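
The `expected_loss` in the two tests added above follows the commented formula `labels * -log(sigmoid(logits)) + (1 - labels) * -log(1 - sigmoid(logits))`, summed over all entries and divided by the batch size. A standalone NumPy sketch of that computation, mirroring the test constants (illustrative only, not the test helper itself):

```python
import numpy as np

def _sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

logits = np.array([[-1.0, 1.0], [-1.5, 1.5]])
labels = np.array([[1.0, 0.0], [1.0, 1.0]])

# labels * -log(sigmoid(logits)) + (1 - labels) * -log(1 - sigmoid(logits))
xent = (labels * -np.log(_sigmoid(logits)) +
        (1.0 - labels) * -np.log(1.0 - _sigmoid(logits)))

expected_loss = np.sum(xent) / 2.0   # sum over all entries, divided by batch_size
print(expected_loss)
```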
diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns.py b/tensorflow/contrib/estimator/python/estimator/logit_fns.py
index 09c2862ccd..c8b0dd6297 100644
--- a/tensorflow/contrib/estimator/python/estimator/logit_fns.py
+++ b/tensorflow/contrib/estimator/python/estimator/logit_fns.py
@@ -41,10 +41,10 @@ from __future__ import print_function
import six
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import dnn as dnn_core
from tensorflow.python.estimator.canned import linear as linear_core
from tensorflow.python.framework import ops
+from tensorflow.python.util import function_utils
# pylint: disable=protected-access
dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
@@ -72,7 +72,7 @@ def call_logit_fn(logit_fn, features, mode, params, config):
ValueError: if logit_fn does not return a Tensor or a dictionary mapping
strings to Tensors.
"""
- logit_fn_args = util.fn_args(logit_fn)
+ logit_fn_args = function_utils.fn_args(logit_fn)
kwargs = {}
if 'mode' in logit_fn_args:
kwargs['mode'] = mode
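
The call sites switched to `function_utils.fn_args` in this change all follow the same pattern: introspect the callable's argument names and forward only the keyword arguments it accepts. A plain-`inspect` sketch of the idea (not the actual TF helper; `my_logit_fn` is a made-up example and Python 3's `inspect.signature` is assumed):

```python
import inspect

def call_with_supported_kwargs(fn, **available):
  """Forwards only the keyword arguments that fn actually declares."""
  accepted = set(inspect.signature(fn).parameters)
  return fn(**{k: v for k, v in available.items() if k in accepted})

def my_logit_fn(features, mode):     # hypothetical logit_fn that ignores params/config
  return features

print(call_with_supported_kwargs(
    my_logit_fn, features=[1.0], mode='infer', params={}, config=None))
```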
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index f8564446e5..cda23aa437 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -32,7 +32,6 @@ import six
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import ops as ops_lib
@@ -48,6 +47,7 @@ from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_setter as device_setter_lib
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.util import deprecation
+from tensorflow.python.util import function_utils
@deprecation.deprecated(
@@ -521,7 +521,7 @@ def _get_loss_towers(model_fn,
"""Replicate the loss computation across devices."""
tower_specs = []
- model_fn_args = util.fn_args(model_fn)
+ model_fn_args = function_utils.fn_args(model_fn)
optional_params = {}
if 'params' in model_fn_args:
optional_params['params'] = copy.deepcopy(params)
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index 2889e93743..9f5fee4542 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -570,7 +570,7 @@ class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest):
'predicted_distributions': self._predicted_distributions,
}
self._expected_loss = 1.61610
- self._expected_op_name = 'mutual_information_loss/mul'
+ self._expected_op_name = 'mutual_information_loss/mul_1'
self._batch_size = 2
diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
index 60281951dd..66939fbb0f 100644
--- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
+++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc
@@ -115,7 +115,7 @@ static void CheckOpsSupport(const GraphDef& graph_def,
HexagonOpsDefinitions::getInstance();
LOG(INFO) << "Checking " << graph_def.node_size() << " nodes";
LOG(INFO) << "dump_all_nodes = " << dump_all_nodes
- << ", dump_shape_and_tpye = " << dump_shape_and_type;
+ << ", dump_shape_and_type = " << dump_shape_and_type;
std::unordered_set<string> unsupported_ops;
bool all_supported = true;
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index b01fd5d5c9..56e9194ceb 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1333,7 +1333,7 @@ class DropoutTest(test.TestCase):
with self.test_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.dropout(images)
- self.assertEqual(output.op.name, 'Dropout/dropout/mul')
+ self.assertEqual(output.op.name, 'Dropout/dropout_1/mul')
output.get_shape().assert_is_compatible_with(
ops.convert_to_tensor(images).get_shape())
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 4a360711f8..0fdbe8f630 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -434,6 +434,7 @@ py_test(
name = "kmeans_test",
size = "medium",
srcs = ["python/learn/estimators/kmeans_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = [
"noasan", # b/73741358
@@ -745,7 +746,7 @@ py_test(
tf_py_test(
name = "graph_io_test",
- size = "small",
+ size = "medium",
srcs = ["python/learn/learn_io/graph_io_test.py"],
additional_deps = [
":learn",
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index e28e6854a5..339c4e0e36 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -1862,12 +1862,12 @@ def _get_arguments(func):
if hasattr(func, "__code__"):
# Regular function.
return tf_inspect.getargspec(func)
- elif hasattr(func, "__call__"):
- # Callable object.
- return _get_arguments(func.__call__)
elif hasattr(func, "func"):
# Partial function.
return _get_arguments(func.func)
+ elif hasattr(func, "__call__"):
+ # Callable object.
+ return _get_arguments(func.__call__)
def _verify_loss_fn_args(loss_fn):
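
The reordering above matters because a `functools.partial` object exposes both `.func` and `.__call__`; checking `.func` first unwraps the partial so the wrapped function's real argument list is recovered. A small sketch of the distinction (illustrative only; `loss_fn` is a made-up example):

```python
import functools
import inspect

def loss_fn(labels, logits, features=None):
  return None

partial_fn = functools.partial(loss_fn, features={})

# A partial exposes both attributes, so the `.func` branch must come first
# to recover the wrapped function's real argument list.
print(hasattr(partial_fn, 'func'), hasattr(partial_fn, '__call__'))   # True True
print(inspect.getfullargspec(partial_fn.func).args)   # ['labels', 'logits', 'features']
```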
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index dfc6a393d0..541da90617 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -38,19 +38,19 @@ from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.python.estimator import estimator as core_estimator
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
__all__ = ["Experiment"]
def _get_standardized_predicate_fn(predicate_fn):
- pred_fn_args = estimator_util.fn_args(predicate_fn)
+ pred_fn_args = function_utils.fn_args(predicate_fn)
if "checkpoint_path" not in pred_fn_args:
# pylint: disable=unused-argument
def _pred_fn_wrapper(eval_results, checkpoint_path):
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 10065e894c..55b984f260 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -6,8 +6,6 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops")
-exports_files(["LICENSE"])
-
exports_files(glob([
"testdata/*.bin",
"testdata/*.pb",
@@ -114,6 +112,7 @@ cc_library(
"interpreter.cc",
"model.cc",
"nnapi_delegate.cc",
+ "op_resolver.cc",
"optional_debug_tools.cc",
],
hdrs = [
@@ -124,6 +123,7 @@ cc_library(
"interpreter.h",
"model.h",
"nnapi_delegate.h",
+ "op_resolver.h",
"optional_debug_tools.h",
],
copts = tflite_copts(),
@@ -226,6 +226,18 @@ cc_test(
],
)
+# Test OpResolver.
+cc_test(
+ name = "op_resolver_test",
+ size = "small",
+ srcs = ["op_resolver_test.cc"],
+ deps = [
+ ":framework",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
# Test the C extension API code.
cc_test(
name = "context_test",
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index a038acf284..1d0ad2d2db 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -90,6 +90,8 @@ typedef enum {
kTfLiteBuiltinGreaterEqual = 62,
kTfLiteBuiltinLessEqual = 63,
kTfLiteBuiltinSelect = 64,
+ kTfLiteBuiltinSlice = 65,
+ kTfLiteBuiltinSin = 66,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index 12841d233c..4eb66cc225 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -370,13 +370,21 @@ typedef struct _TfLiteRegistration {
// Builtin codes. If this kernel refers to a builtin this is the code
// of the builtin. This is so we can do marshaling to other frameworks like
- // NN API. Note, it is the responsibility of the registration binder to
- // set this properly.
+ // NN API.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
int32_t builtin_code;
// Custom op name. If the op is a builtin, this will be null.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
// WARNING: This is an experimental interface that is subject to change.
const char* custom_name;
+
+ // The version of the op.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ int version;
} TfLiteRegistration;
// WARNING: This is an experimental interface that is subject to change.
diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm
index d74e275f04..59b575ab6e 100644
--- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm
+++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm
@@ -25,8 +25,8 @@
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
#include "tensorflow/contrib/lite/string_util.h"
-#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
#define LOG(x) std::cerr
diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
index 0ab7aa25d0..32da7f7e4f 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
+++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
@@ -24,8 +24,8 @@
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
#include "tensorflow/contrib/lite/string_util.h"
-#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
#include "ios_image_load.h"
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
index 2a64c1de72..e36218e4f1 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
@@ -62,8 +62,8 @@ void resize(T* out, uint8_t* in, int image_height, int image_width,
{1, wanted_height, wanted_width, wanted_channels}, quant);
ops::builtin::BuiltinOpResolver resolver;
- TfLiteRegistration* resize_op =
- resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR);
+ const TfLiteRegistration* resize_op =
+ resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR, 1);
auto* params = reinterpret_cast<TfLiteResizeBilinearParams*>(
malloc(sizeof(TfLiteResizeBilinearParams)));
params->align_corners = false;
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc
index 456c5c6dc7..966fcd2a31 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.cc
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc
@@ -77,14 +77,13 @@ void PrintProfilingInfo(const profiling::ProfileEvent* e, uint32_t op_index,
// time (ms) , Node xxx, OpCode xxx, symbolic name
// 5.352, Node 5, OpCode 4, DEPTHWISE_CONV_2D
-
LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3)
<< (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0
<< ", Node " << std::setw(3) << std::setprecision(3) << op_index
<< ", OpCode " << std::setw(3) << std::setprecision(3)
<< registration.builtin_code << ", "
<< EnumNameBuiltinOperator(
- (BuiltinOperator)registration.builtin_code)
+ static_cast<BuiltinOperator>(registration.builtin_code))
<< "\n";
}
@@ -190,13 +189,13 @@ void RunInference(Settings* s) {
if (s->profiling) profiler->StartProfiling();
struct timeval start_time, stop_time;
- gettimeofday(&start_time, NULL);
+ gettimeofday(&start_time, nullptr);
for (int i = 0; i < s->loop_count; i++) {
if (interpreter->Invoke() != kTfLiteOk) {
LOG(FATAL) << "Failed to invoke tflite!\n";
}
}
- gettimeofday(&stop_time, NULL);
+ gettimeofday(&stop_time, nullptr);
LOG(INFO) << "invoked \n";
LOG(INFO) << "average time: "
<< (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
@@ -271,17 +270,17 @@ int Main(int argc, char** argv) {
int c;
while (1) {
static struct option long_options[] = {
- {"accelerated", required_argument, 0, 'a'},
- {"count", required_argument, 0, 'c'},
- {"verbose", required_argument, 0, 'v'},
- {"image", required_argument, 0, 'i'},
- {"labels", required_argument, 0, 'l'},
- {"tflite_model", required_argument, 0, 'm'},
- {"profiling", required_argument, 0, 'p'},
- {"threads", required_argument, 0, 't'},
- {"input_mean", required_argument, 0, 'b'},
- {"input_std", required_argument, 0, 's'},
- {0, 0, 0, 0}};
+ {"accelerated", required_argument, nullptr, 'a'},
+ {"count", required_argument, nullptr, 'c'},
+ {"verbose", required_argument, nullptr, 'v'},
+ {"image", required_argument, nullptr, 'i'},
+ {"labels", required_argument, nullptr, 'l'},
+ {"tflite_model", required_argument, nullptr, 'm'},
+ {"profiling", required_argument, nullptr, 'p'},
+ {"threads", required_argument, nullptr, 't'},
+ {"input_mean", required_argument, nullptr, 'b'},
+ {"input_std", required_argument, nullptr, 's'},
+ {nullptr, 0, nullptr, 0}};
/* getopt_long stores the option index here. */
int option_index = 0;
@@ -294,15 +293,14 @@ int Main(int argc, char** argv) {
switch (c) {
case 'a':
- s.accel = strtol( // NOLINT(runtime/deprecated_fn)
- optarg, (char**)NULL, 10);
+ s.accel = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
break;
case 'b':
- s.input_mean = strtod(optarg, NULL);
+ s.input_mean = strtod(optarg, nullptr);
break;
case 'c':
- s.loop_count = strtol( // NOLINT(runtime/deprecated_fn)
- optarg, (char**)NULL, 10);
+ s.loop_count =
+ strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
break;
case 'i':
s.input_bmp_name = optarg;
@@ -314,19 +312,19 @@ int Main(int argc, char** argv) {
s.model_name = optarg;
break;
case 'p':
- s.profiling = strtol( // NOLINT(runtime/deprecated_fn)
- optarg, (char**)NULL, 10);
+ s.profiling =
+ strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
break;
case 's':
- s.input_std = strtod(optarg, NULL);
+ s.input_std = strtod(optarg, nullptr);
break;
case 't':
s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn)
- optarg, (char**)NULL, 10);
+ optarg, nullptr, 10);
break;
case 'v':
- s.verbose = strtol( // NOLINT(runtime/deprecated_fn)
- optarg, (char**)NULL, 10);
+ s.verbose =
+ strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
break;
case 'h':
case '?':
diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md
index d7cc854eba..972e57f73e 100644
--- a/tensorflow/contrib/lite/g3doc/custom_operators.md
+++ b/tensorflow/contrib/lite/g3doc/custom_operators.md
@@ -39,7 +39,7 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
int num_dims = NumDimensions(input);
@@ -54,7 +54,7 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
using namespace tflite;
- TfLiteTensor* input = GetInput(context, node,0);
+ const TfLiteTensor* input = GetInput(context, node,0);
TfLiteTensor* output = GetOutput(context, node,0);
float* input_data = input->data.f;
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index f45fcceb2e..f52d0fb08f 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -134,7 +134,6 @@ following common ops are not supported at the moment:
* [tf.depth_to_space](https://www.tensorflow.org/api_docs/python/tf/depth_to_space)
* [tf.gather](https://www.tensorflow.org/api_docs/python/tf/gather)
* [tf.image.resize_bilinear](https://www.tensorflow.org/api_docs/python/tf/image/resize_bilinear)
-* [tf.slice](https://www.tensorflow.org/api_docs/python/tf/slice)
* [tf.tanh](https://www.tensorflow.org/api_docs/python/tf/tanh)
## TensorFlow Lite Operations
@@ -523,6 +522,19 @@ Options {
}
```
+**SLICE**
+
+```
+Inputs {
+ 0: tensor
+ 1: 1D tensor
+ 2: 1D tensor
+}
+Outputs {
+ 0: slice of the input tensor of the given size from the given begin index.
+}
+```
+
**SOFTMAX**
```
@@ -608,7 +620,7 @@ Outputs {
0: slice of the input tensor of the given size
}
Options {
- begin_mask: mask for begin indicies
+ begin_mask: mask for begin indices
end_mask: mask for end indices
shrink_axis_mask: mask that indicates which dimensions to remove
}
@@ -623,7 +635,7 @@ Inputs {
}
Outputs {
0: k largest element along each last dimensional slice
- 1: indicies of values within the last dimension of the input ensor
+ 1: indices of values within the last dimension of the input tensor
}
```
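The new SLICE entry documented above takes the tensor to slice plus two 1-D tensors giving the begin indices and sizes, and produces the corresponding sub-tensor. Per dimension the semantics are simply `output[i] = input[begin + i]` for `i` in `[0, size)`; a plain C++ sketch of the one-dimensional case (purely illustrative, not the TFLite kernel) is:

```c++
// Illustration of the SLICE semantics described in the entry above.
#include <cstdio>
#include <vector>

std::vector<float> Slice1D(const std::vector<float>& input, int begin, int size) {
  // begin and size correspond to the two 1-D input tensors of the op.
  return std::vector<float>(input.begin() + begin, input.begin() + begin + size);
}

int main() {
  std::vector<float> in = {10, 11, 12, 13, 14, 15};
  std::vector<float> out = Slice1D(in, /*begin=*/2, /*size=*/3);
  for (float v : out) std::printf("%g ", v);  // prints: 12 13 14
  std::printf("\n");
}
```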
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml
index ba63dce5d9..95b6b7016f 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml
@@ -31,6 +31,7 @@
android:theme="@style/MaterialTheme">
<activity android:name="com.example.android.tflitecamerademo.CameraActivity"
+ android:screenOrientation="portrait"
android:label="@string/app_name">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml
index 72a229ecdb..ddb099a950 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml
@@ -28,7 +28,7 @@
<LinearLayout
android:layout_width="wrap_content"
android:layout_height="wrap_content"
- android:layout_alignParentBottom="true"
+ android:layout_above="@+id/bottom_info_view"
android:layout_alignParentEnd="false"
android:layout_alignParentStart="true"
android:layout_alignParentTop="false"
@@ -57,32 +57,39 @@
android:textStyle="bold" />
</LinearLayout>
+ <LinearLayout
+ android:orientation="horizontal"
+ android:background="#513400"
+ android:layout_alignParentBottom="true"
- <RelativeLayout
- android:id="@+id/control2"
android:layout_width="match_parent"
- android:layout_height="135dp"
- android:layout_alignParentLeft="true"
- android:layout_alignParentStart="true"
- android:layout_alignTop="@+id/control"
- android:layout_marginLeft="300dp"
- android:layout_marginStart="300dp"
- android:background="#bb7700">
-
+ android:id="@+id/bottom_info_view"
+ android:layout_marginBottom="10dp"
+ android:layout_height="50dp">
+ <TextView
+ android:layout_width="wrap_content"
+ android:layout_height="match_parent"
+ android:textColor="@android:color/white"
+ android:textAlignment="center"
+ android:gravity="center"
+ android:text="Threads:"/>
+ <NumberPicker
+ android:id="@+id/np"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_marginLeft="10dp"
+ android:theme="@style/AppTheme.Picker"
+ android:visibility="visible" />
<ToggleButton
android:id="@+id/button"
android:textOff="@string/tflite"
android:textOn="@string/nnapi"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
- android:layout_alignParentLeft="true"
- android:layout_alignParentStart="true" />
+ android:layout_marginLeft="10dp"
+ android:background="#0000000f"
+ android:textColor="@android:color/white" />
+ </LinearLayout>
+
- <NumberPicker
- android:id="@+id/np"
- android:layout_width="wrap_content"
- android:layout_height="wrap_content"
- android:layout_below="@+id/button"
- android:visibility="visible" />
- </RelativeLayout>
</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
index 72a229ecdb..e567009a42 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
@@ -28,7 +28,7 @@
<LinearLayout
android:layout_width="wrap_content"
android:layout_height="wrap_content"
- android:layout_alignParentBottom="true"
+ android:layout_above="@+id/bottom_info_view"
android:layout_alignParentEnd="false"
android:layout_alignParentStart="true"
android:layout_alignParentTop="false"
@@ -57,32 +57,38 @@
android:textStyle="bold" />
</LinearLayout>
+ <LinearLayout
+ android:orientation="horizontal"
+ android:background="#aa7700"
+ android:layout_alignParentBottom="true"
- <RelativeLayout
- android:id="@+id/control2"
android:layout_width="match_parent"
- android:layout_height="135dp"
- android:layout_alignParentLeft="true"
- android:layout_alignParentStart="true"
- android:layout_alignTop="@+id/control"
- android:layout_marginLeft="300dp"
- android:layout_marginStart="300dp"
- android:background="#bb7700">
-
+ android:id="@+id/bottom_info_view"
+ android:layout_marginBottom="10dp"
+ android:layout_height="50dp">
+ <TextView
+ android:layout_width="wrap_content"
+ android:layout_height="match_parent"
+ android:textColor="@android:color/white"
+ android:textAlignment="center"
+ android:gravity="center"
+ android:text="@string/threads" />
+ <NumberPicker
+ android:id="@+id/np"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_marginLeft="10dp"
+ android:theme="@style/AppTheme.Picker"
+ android:visibility="visible" />
<ToggleButton
android:id="@+id/button"
android:textOff="@string/tflite"
android:textOn="@string/nnapi"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
- android:layout_alignParentLeft="true"
- android:layout_alignParentStart="true" />
+ android:layout_marginLeft="10dp"
+ android:background="#0000000f"
+ android:textColor="@android:color/white" />
- <NumberPicker
- android:id="@+id/np"
- android:layout_width="wrap_content"
- android:layout_height="wrap_content"
- android:layout_below="@+id/button"
- android:visibility="visible" />
- </RelativeLayout>
+ </LinearLayout>
</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml
index 0a71dbd0e8..7af8f3a98c 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml
@@ -16,7 +16,7 @@
-->
<resources>
- <string name="app_name">TfLiteCameraDemo</string>
+ <string name="app_name">TfLite Camera Demo</string>
<string name="intro_message">
<![CDATA[
@@ -27,4 +27,5 @@
]]>
</string>
+ <string name="threads">Threads:</string>
</resources>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml
index 3f3bdfb494..1752b3b5f9 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml
@@ -14,5 +14,10 @@
limitations under the License.
-->
<resources>
- <style name="MaterialTheme" parent="android:Theme.Material.Light.NoActionBar.Fullscreen" />
+ <style name="MaterialTheme" parent="android:Theme.Material.Light.NoActionBar.Fullscreen" />
+ <style name="AppTheme.Picker" parent="android:Theme.Material.Light.NoActionBar.Fullscreen" >
+ <item name="android:textColorPrimary">@android:color/white</item>
+  </style>
+
</resources>
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 79e3c9f266..6e2e790517 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -143,6 +143,7 @@ cc_library(
"depthwise_conv.cc",
"dequantize.cc",
"div.cc",
+ "elementwise.cc",
"embedding_lookup.cc",
"embedding_lookup_sparse.cc",
"exp.cc",
@@ -166,6 +167,7 @@ cc_library(
"resize_bilinear.cc",
"select.cc",
"skip_gram.cc",
+ "slice.cc",
"space_to_batch_nd.cc",
"space_to_depth.cc",
"split.cc",
@@ -455,6 +457,19 @@ tf_cc_test(
)
tf_cc_test(
+ name = "elementwise_test",
+ size = "small",
+ srcs = ["elementwise_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
name = "unidirectional_sequence_lstm_test",
size = "small",
srcs = ["unidirectional_sequence_lstm_test.cc"],
@@ -888,6 +903,23 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "slice_test",
+ size = "small",
+ srcs = [
+ "slice_test.cc",
+ ],
+ tags = [
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 39a54c9396..4972159a05 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -55,7 +55,7 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
@@ -68,7 +68,7 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
@@ -95,7 +95,7 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
@@ -126,7 +126,7 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
@@ -153,9 +153,9 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
- TfLiteTensor* alpha = GetInput(context, node, 1);
+ const TfLiteTensor* alpha = GetInput(context, node, 1);
output->type = input->type;
@@ -179,7 +179,7 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
case kTfLiteFloat32: {
@@ -197,7 +197,7 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
case kTfLiteFloat32: {
@@ -217,7 +217,7 @@ TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
case kTfLiteFloat32: {
@@ -236,7 +236,7 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
case kTfLiteFloat32: {
@@ -265,7 +265,7 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
case kTfLiteFloat32: {
@@ -292,7 +292,7 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
}
// Takes a 2D tensor and perform softmax along the second dimension.
-void Softmax2DFloat(TfLiteTensor* input, TfLiteTensor* output,
+void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
@@ -327,7 +327,7 @@ void Softmax2DFloat(TfLiteTensor* input, TfLiteTensor* output,
}
}
-void Softmax2DQuantized(TfLiteTensor* input, TfLiteTensor* output,
+void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
// TODO(ahentz): this is arguably a dirty trick. Since the implementation
// always traverses the last dimension of a 4D tensor, we will pretend our 2D
@@ -343,14 +343,14 @@ void Softmax2DQuantized(TfLiteTensor* input, TfLiteTensor* output,
}
// Takes a 4D tensor and perform softmax along the forth dimension.
-void Softmax4DFloat(TfLiteTensor* input, TfLiteTensor* output,
+void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
optimized_ops::Softmax(GetTensorData<float>(input), GetTensorDims(input),
params->beta, GetTensorData<float>(output),
GetTensorDims(output));
}
-void Softmax4DQuantized(TfLiteTensor* input, TfLiteTensor* output,
+void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorDims(input),
data->input_multiplier, data->input_left_shift,
@@ -362,7 +362,7 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
// TODO(ahentz): consider an implementation that works for many (all?)
@@ -402,7 +402,7 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
case kTfLiteFloat32:
@@ -417,9 +417,9 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, 0);
- TfLiteTensor* alpha = GetInput(context, node, 1);
- TfLiteTensor* output = GetOutput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* alpha = GetInput(context, node, 1);
+ const TfLiteTensor* output = GetOutput(context, node, 0);
if (input->type != kTfLiteFloat32) {
context->ReportError(context, "Only float32 supported currently.");
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index e0aa070e2d..7ca1e35489 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -57,8 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
@@ -80,7 +80,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
void EvalAddFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteAddParams* params, const OpData* data,
- TfLiteTensor* input1, TfLiteTensor* input2,
+ const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRangeFloat(params->activation, &output_activation_min,
@@ -109,7 +109,7 @@ void EvalAddFloat(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteAddParams* params, const OpData* data,
- TfLiteTensor* input1, TfLiteTensor* input2,
+ const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output) {
auto input1_offset = -input1->params.zero_point;
auto input2_offset = -input2->params.zero_point;
@@ -164,8 +164,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
diff --git a/tensorflow/contrib/lite/kernels/arg_max.cc b/tensorflow/contrib/lite/kernels/arg_max.cc
index a2c5e4cead..566d37047a 100644
--- a/tensorflow/contrib/lite/kernels/arg_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_max.cc
@@ -33,8 +33,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* axis = GetInput(context, node, kAxis);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* axis = GetInput(context, node, kAxis);
// Make sure the axis is only 1 dimension.
TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);
@@ -79,8 +79,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// The current impl actually ignores the axis argument.
// Only determine the index of the maximum value in the last dimension.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* axis = GetInput(context, node, kAxis);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
#define TF_LITE_ARG_MAX(data_type, axis_type, output_type) \
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
index 602f3888c1..91d8dd3fa7 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
@@ -72,7 +72,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
@@ -102,7 +102,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioSpectrogramParams*>(node->user_data);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size,
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
index 2c5074eca3..0907547f9f 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -12,18 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unistd.h>
-#include <cassert>
-#include <cmath>
-#include <cstdio>
-#include <cstdlib>
-#include <iostream>
-#include <limits>
+#include <stddef.h>
+#include <stdint.h>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -35,20 +31,29 @@ constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int KHiddenStateTensor = 0;
+constexpr int kHiddenStateTensor = 0;
constexpr int kOutputTensor = 1;
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* input_weights =
- &context->tensors[node->inputs->data[kWeightsTensor]];
- TfLiteTensor* recurrent_weights =
- &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
- TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
+ const TfLiteTensor* recurrent_weights =
+ GetInput(context, node, kRecurrentWeightsTensor);
+ const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@@ -58,10 +63,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]);
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]);
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
- TfLiteTensor* hidden_state =
- &context->tensors[node->outputs->data[KHiddenStateTensor]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+ TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Resize state.
TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
@@ -80,25 +86,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size_array));
+ // Allocate temporary tensors to store quantized values of input and
+ // hidden_state tensors.
+ if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[1] = *scratch_tensor_index + 1;
+ TfLiteTensor* hidden_state_quantized =
+ GetTemporary(context, node, /*index=*/1);
+ hidden_state_quantized->type = kTfLiteUInt8;
+ hidden_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
+ hidden_state->dims)) {
+ TfLiteIntArray* hidden_state_quantized_size =
+ TfLiteIntArrayCopy(hidden_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, hidden_state_quantized,
+ hidden_state_quantized_size));
+ }
+ }
+
return kTfLiteOk;
}
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
-
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* input_weights =
- &context->tensors[node->inputs->data[kWeightsTensor]];
- TfLiteTensor* recurrent_weights =
- &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
- TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
- TfLiteTensor* hidden_state =
- &context->tensors[node->outputs->data[KHiddenStateTensor]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
-
- // Initialize the pointer bias.
- const float* bias_ptr = bias->data.f;
-
+TfLiteStatus EvalFloat(const TfLiteTensor* input,
+ const TfLiteTensor* input_weights,
+ const TfLiteTensor* recurrent_weights,
+ const TfLiteTensor* bias, const TfLiteRNNParams* params,
+ TfLiteTensor* hidden_state, TfLiteTensor* output) {
const int batch_size = input->dims->data[0];
const int num_units = input_weights->dims->data[0];
const int input_size = input->dims->data[1];
@@ -108,9 +133,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Initialize the pointer to input and output.
const float* input_ptr_batch = input->data.f;
float* output_ptr_batch = output->data.f;
- // Initialize input_weights and recurrent_weights.
+ // Initialize input_weights, recurrent_weights and bias.
const float* input_weights_ptr = input_weights->data.f;
const float* recurrent_weights_ptr = recurrent_weights->data.f;
+ const float* bias_ptr = bias->data.f;
kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr,
recurrent_weights_ptr, bias_ptr, input_size,
@@ -119,11 +145,81 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus EvalQuantized(const TfLiteTensor* input,
+ const TfLiteTensor* input_weights,
+ const TfLiteTensor* recurrent_weights,
+ const TfLiteTensor* bias,
+ const TfLiteRNNParams* params,
+ TfLiteTensor* input_scratch,
+ TfLiteTensor* hidden_state_scratch,
+ TfLiteTensor* hidden_state, TfLiteTensor* output) {
+ const int batch_size = input->dims->data[0];
+ const int num_units = input_weights->dims->data[0];
+ const int input_size = input->dims->data[1];
+
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f;
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch = input->data.f;
+ float* output_ptr_batch = output->data.f;
+ // Initialize input_weights, recurrent_weights and bias.
+ const int8_t* input_weights_ptr =
+ reinterpret_cast<const int8_t*>(input_weights->data.uint8);
+ const int8_t* recurrent_weights_ptr =
+ reinterpret_cast<const int8_t*>(recurrent_weights->data.uint8);
+ const float* bias_ptr = bias->data.f;
+ // Get the scale of the quantized weights.
+ float input_weights_scale = input_weights->params.scale;
+ float recurrent_weights_scale = recurrent_weights->params.scale;
+ // Initialize temporary storage for quantized values.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_scratch->data.uint8);
+ int8_t* quantized_hidden_state_ptr =
+ reinterpret_cast<int8_t*>(hidden_state_scratch->data.uint8);
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, input_weights_ptr, input_weights_scale,
+ recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
+ num_units, batch_size, params->activation, quantized_input_ptr,
+ quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
+ const TfLiteTensor* recurrent_weights =
+ GetInput(context, node, kRecurrentWeightsTensor);
+ const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // We already checked that weight types are consistent, so branch on one.
+ switch (input_weights->type) {
+ case kTfLiteFloat32:
+ return EvalFloat(input, input_weights, recurrent_weights, bias, params,
+ hidden_state, output);
+ case kTfLiteUInt8: {
+ // TODO(mirkov): implement eval with quantized inputs as well.
+ TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
+ TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
+ return EvalQuantized(input, input_weights, recurrent_weights, bias,
+ params, input_quantized, hidden_state_quantized,
+ hidden_state, output);
+ }
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
} // namespace rnn
TfLiteRegistration* Register_RNN() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- rnn::Prepare, rnn::Eval};
+ static TfLiteRegistration r = {rnn::Init, rnn::Free, rnn::Prepare, rnn::Eval};
return &r;
}
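The rewritten RNN kernel above adds a hybrid path: the weights and recurrent weights may be stored as 8-bit symmetrically quantized tensors, and at evaluation time the float input and hidden state are quantized into the two scratch tensors allocated in `Prepare`, so that `kernel_utils::RnnBatchStep` can run the matrix multiplications in 8-bit and rescale the results using the weight scales passed alongside them. A rough sketch of per-tensor symmetric quantization, under the assumption of a single scale that maps the maximum absolute value to 127 (this helper is illustrative, not the kernel_utils implementation), is:

```c++
// Symmetric per-tensor quantization sketch: float -> int8 with one scale.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void SymmetricQuantize(const std::vector<float>& values,
                       std::vector<int8_t>* quantized, float* scale) {
  float max_abs = 0.f;
  for (float v : values) max_abs = std::max(max_abs, std::fabs(v));
  // Map [-max_abs, max_abs] onto [-127, 127]; guard against an all-zero input.
  *scale = max_abs > 0.f ? max_abs / 127.f : 1.f;
  quantized->resize(values.size());
  for (size_t i = 0; i < values.size(); ++i) {
    (*quantized)[i] = static_cast<int8_t>(std::round(values[i] / *scale));
  }
}
```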
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
index fa7ef525db..96465fcaf0 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
@@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
// Unit test for TFLite RNN op.
-#include <iomanip>
+#include <string.h>
+#include <initializer_list>
+#include <memory>
#include <vector>
#include <gmock/gmock.h>
@@ -122,13 +124,62 @@ static float rnn_golden_output[] = {
0, 2.02616, 0, 0.728256, 0.84183, 0.0907453,
0.628881, 3.58099, 1.49974, 0};
+static std::initializer_list<float> rnn_weights = {
+ 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
+ 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
+ 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
+ -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
+ -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
+ -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
+ -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
+ 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
+ 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
+ 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
+ -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
+ 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
+ -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
+ -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
+ 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
+ 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
+ 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
+ -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
+ 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
+ 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
+ -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
+ 0.277308, 0.415818};
+
+static std::initializer_list<float> rnn_recurrent_weights = {
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1};
+
+static std::initializer_list<float> rnn_bias = {
+ 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568,
+ -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178,
+ 0.37197268, 0.61957061, 0.3956964, -0.37609905};
+
class RNNOpModel : public SingleOpModel {
public:
- RNNOpModel(int batches, int units, int size)
+ RNNOpModel(int batches, int units, int size,
+ const TensorType& weights = TensorType_FLOAT32,
+ const TensorType& recurrent_weights = TensorType_FLOAT32)
: batches_(batches), units_(units), input_size_(size) {
input_ = AddInput(TensorType_FLOAT32);
- weights_ = AddInput(TensorType_FLOAT32);
- recurrent_weights_ = AddInput(TensorType_FLOAT32);
+ weights_ = AddInput(weights);
+ recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
hidden_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -173,7 +224,7 @@ class RNNOpModel : public SingleOpModel {
int num_units() { return units_; }
int num_batches() { return batches_; }
- private:
+ protected:
int input_;
int weights_;
int recurrent_weights_;
@@ -186,53 +237,26 @@ class RNNOpModel : public SingleOpModel {
int input_size_;
};
-TEST(FullyConnectedOpTest, BlackBoxTest) {
+// The hybrid model has quantized weights and recurrent_weights.
+class HybridRNNOpModel : public RNNOpModel {
+ public:
+ HybridRNNOpModel(int batches, int units, int size)
+ : RNNOpModel(batches, units, size, TensorType_UINT8, TensorType_UINT8) {}
+
+ void SetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(weights_, f);
+ }
+
+ void SetRecurrentWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_weights_, f);
+ }
+};
+
+TEST(RnnOpTest, BlackBoxTest) {
RNNOpModel rnn(2, 16, 8);
- rnn.SetWeights(
- {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
- 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
- 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
- -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
- -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
- -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
- -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
- 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
- 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
- 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
- -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
- 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
- -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
- -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
- 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
- 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
- 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
- -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
- 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
- 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
- -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
- 0.277308, 0.415818});
-
- rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
- -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
- 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
- -0.37609905});
-
- rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1});
+ rnn.SetWeights(rnn_weights);
+ rnn.SetBias(rnn_bias);
+ rnn.SetRecurrentWeights(rnn_recurrent_weights);
rnn.ResetHiddenState();
const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
@@ -256,6 +280,35 @@ TEST(FullyConnectedOpTest, BlackBoxTest) {
}
}
+TEST(HybridRnnOpTest, BlackBoxTest) {
+ HybridRNNOpModel rnn(2, 16, 8);
+ rnn.SetWeights(rnn_weights);
+ rnn.SetBias(rnn_bias);
+ rnn.SetRecurrentWeights(rnn_recurrent_weights);
+
+ rnn.ResetHiddenState();
+ const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
+ (rnn.input_size() * rnn.num_batches());
+
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start = rnn_input + i * rnn.input_size();
+ float* batch_end = batch_start + rnn.input_size();
+ rnn.SetInput(0, batch_start, batch_end);
+ rnn.SetInput(rnn.input_size(), batch_start, batch_end);
+
+ rnn.Invoke();
+
+ float* golden_start = rnn_golden_output + i * rnn.num_units();
+ float* golden_end = golden_start + rnn.num_units();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ expected, /*max_abs_error=*/0.0104)));
+ }
+}
+
} // namespace
} // namespace tflite
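The hybrid black-box test above reuses the same weights, bias, and inputs as the float test but compares against the float golden output with a `max_abs_error` of 0.0104 rather than an exact match; that slack presumably absorbs the rounding error introduced by quantizing the weights to 8 bits.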
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index 90edf4f9e3..262e1aeab1 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -40,9 +40,9 @@ struct BatchToSpaceNDContext {
crops = GetInput(context, node, 2);
output = GetOutput(context, node, 0);
}
- TfLiteTensor* input;
- TfLiteTensor* block_shape;
- TfLiteTensor* crops;
+ const TfLiteTensor* input;
+ const TfLiteTensor* block_shape;
+ const TfLiteTensor* crops;
TfLiteTensor* output;
};
@@ -66,12 +66,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops),
kSpatialDimensionNum);
- // TODO(ycling): Add crops as part of calculation. Remove check for a crops
- // containing all zeroes.
- TF_LITE_ENSURE_EQ(context, crops[0], 0);
- TF_LITE_ENSURE_EQ(context, crops[1], 0);
- TF_LITE_ENSURE_EQ(context, crops[2], 0);
- TF_LITE_ENSURE_EQ(context, crops[3], 0);
+ TF_LITE_ENSURE(context, crops[0] >= 0);
+ TF_LITE_ENSURE(context, crops[1] >= 0);
+ TF_LITE_ENSURE(context, crops[2] >= 0);
+ TF_LITE_ENSURE(context, crops[3] >= 0);
// Number of batch must be multiple of (block_shape[0] * block_shape[1]).
TF_LITE_ENSURE_EQ(context,
@@ -79,8 +77,16 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
const int output_batch_size =
input_size->data[0] / (block_shape[0] * block_shape[1]);
- const int output_height = input_size->data[1] * block_shape[0];
- const int output_width = input_size->data[2] * block_shape[1];
+
+ const int crops_top = crops[0];
+ const int crops_bottom = crops[1];
+ const int crops_left = crops[2];
+ const int crops_right = crops[3];
+ const int output_height =
+ input_size->data[1] * block_shape[0] - crops_top - crops_bottom;
+ const int output_width =
+ input_size->data[2] * block_shape[1] - crops_left - crops_right;
+
const int output_channel_size = input_size->data[3];
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
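With crops now folded into the shape computation above, the output spatial dimensions shrink by the cropped amounts. As a worked example (illustrative numbers, not from the patch): an input of shape [4, 2, 2, 1] with block_shape [2, 2] and crops [[0, 0], [0, 1]] yields an output of shape [1, 4, 3, 1], since batch = 4 / (2 * 2) = 1, height = 2 * 2 - 0 - 0 = 4, and width = 2 * 2 - 0 - 1 = 3.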
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
index 8485cde1b4..95b025c1b3 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc
@@ -120,16 +120,16 @@ TEST(BatchToSpaceNDOpTest, InvalidShapeTest) {
}
TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) {
- EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 1}),
- "1 != 0");
+ EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, -1}),
+ "crops.3. >= 0 was not true.");
}
TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) {
BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetBlockShape({2, 2});
- m.SetCrops({0, 0, 1, 0});
- EXPECT_DEATH(m.Invoke(), "1 != 0");
+ m.SetCrops({0, 0, -1, 0});
+ EXPECT_DEATH(m.Invoke(), "crops.2. >= 0 was not true.");
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index a35ba23ced..1cd4884696 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -143,13 +143,13 @@ TfLiteStatus CheckLstmTensorDimensions(
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
}
- TfLiteTensor* input_to_forget_weights =
+ const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, input_to_forget_weights_tensor);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
- TfLiteTensor* input_to_cell_weights =
+ const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, input_to_cell_weights_tensor);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
@@ -165,7 +165,7 @@ TfLiteStatus CheckLstmTensorDimensions(
n_output);
}
- TfLiteTensor* recurrent_to_forget_weights =
+ const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, recurrent_to_forget_weights_tensor);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
@@ -173,7 +173,7 @@ TfLiteStatus CheckLstmTensorDimensions(
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
n_output);
- TfLiteTensor* recurrent_to_cell_weights =
+ const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, recurrent_to_cell_weights_tensor);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
@@ -231,16 +231,17 @@ TfLiteStatus CheckLstmTensorDimensions(
TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
}
- TfLiteTensor* forget_gate_bias =
+ const TfLiteTensor* forget_gate_bias =
GetInput(context, node, forget_gate_bias_tensor);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
- TfLiteTensor* cell_bias = GetInput(context, node, cell_gate_bias_tensor);
+ const TfLiteTensor* cell_bias =
+ GetInput(context, node, cell_gate_bias_tensor);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
- TfLiteTensor* output_gate_bias =
+ const TfLiteTensor* output_gate_bias =
GetInput(context, node, output_gate_bias_tensor);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
@@ -312,20 +313,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input->dims->size > 1);
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
- TfLiteTensor* fw_input_to_output_weights =
+ const TfLiteTensor* fw_input_to_output_weights =
GetInput(context, node, kFwInputToOutputWeightsTensor);
const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
n_input);
- TfLiteTensor* fw_recurrent_to_output_weights =
+ const TfLiteTensor* fw_recurrent_to_output_weights =
GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0],
@@ -388,14 +389,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer,
fw_scratch_buffer_size));
// Same for the backward cell.
- TfLiteTensor* bw_input_to_output_weights =
+ const TfLiteTensor* bw_input_to_output_weights =
GetInput(context, node, kBwInputToOutputWeightsTensor);
const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
n_input);
- TfLiteTensor* bw_recurrent_to_output_weights =
+ const TfLiteTensor* bw_recurrent_to_output_weights =
GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
@@ -463,7 +464,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Input tensor.
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
@@ -471,20 +472,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Tensors for the forward cell.
TfLiteTensor* fw_input_to_input_weights =
GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
- TfLiteTensor* fw_input_to_forget_weights =
+ const TfLiteTensor* fw_input_to_forget_weights =
GetInput(context, node, kFwInputToForgetWeightsTensor);
- TfLiteTensor* fw_input_to_cell_weights =
+ const TfLiteTensor* fw_input_to_cell_weights =
GetInput(context, node, kFwInputToCellWeightsTensor);
- TfLiteTensor* fw_input_to_output_weights =
+ const TfLiteTensor* fw_input_to_output_weights =
GetInput(context, node, kFwInputToOutputWeightsTensor);
TfLiteTensor* fw_recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor);
- TfLiteTensor* fw_recurrent_to_forget_weights =
+ const TfLiteTensor* fw_recurrent_to_forget_weights =
GetInput(context, node, kFwRecurrentToForgetWeightsTensor);
- TfLiteTensor* fw_recurrent_to_cell_weights =
+ const TfLiteTensor* fw_recurrent_to_cell_weights =
GetInput(context, node, kFwRecurrentToCellWeightsTensor);
- TfLiteTensor* fw_recurrent_to_output_weights =
+ const TfLiteTensor* fw_recurrent_to_output_weights =
GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
TfLiteTensor* fw_cell_to_input_weights =
@@ -496,10 +497,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* fw_input_gate_bias =
GetOptionalInputTensor(context, node, kFwInputGateBiasTensor);
- TfLiteTensor* fw_forget_gate_bias =
+ const TfLiteTensor* fw_forget_gate_bias =
GetInput(context, node, kFwForgetGateBiasTensor);
- TfLiteTensor* fw_cell_bias = GetInput(context, node, kFwCellGateBiasTensor);
- TfLiteTensor* fw_output_gate_bias =
+ const TfLiteTensor* fw_cell_bias =
+ GetInput(context, node, kFwCellGateBiasTensor);
+ const TfLiteTensor* fw_output_gate_bias =
GetInput(context, node, kFwOutputGateBiasTensor);
TfLiteTensor* fw_projection_weights =
@@ -515,20 +517,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Tensors for the backward cell.
TfLiteTensor* bw_input_to_input_weights =
GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
- TfLiteTensor* bw_input_to_forget_weights =
+ const TfLiteTensor* bw_input_to_forget_weights =
GetInput(context, node, kBwInputToForgetWeightsTensor);
- TfLiteTensor* bw_input_to_cell_weights =
+ const TfLiteTensor* bw_input_to_cell_weights =
GetInput(context, node, kBwInputToCellWeightsTensor);
- TfLiteTensor* bw_input_to_output_weights =
+ const TfLiteTensor* bw_input_to_output_weights =
GetInput(context, node, kBwInputToOutputWeightsTensor);
TfLiteTensor* bw_recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor);
- TfLiteTensor* bw_recurrent_to_forget_weights =
+ const TfLiteTensor* bw_recurrent_to_forget_weights =
GetInput(context, node, kBwRecurrentToForgetWeightsTensor);
- TfLiteTensor* bw_recurrent_to_cell_weights =
+ const TfLiteTensor* bw_recurrent_to_cell_weights =
GetInput(context, node, kBwRecurrentToCellWeightsTensor);
- TfLiteTensor* bw_recurrent_to_output_weights =
+ const TfLiteTensor* bw_recurrent_to_output_weights =
GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
TfLiteTensor* bw_cell_to_input_weights =
@@ -540,10 +542,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* bw_input_gate_bias =
GetOptionalInputTensor(context, node, kBwInputGateBiasTensor);
- TfLiteTensor* bw_forget_gate_bias =
+ const TfLiteTensor* bw_forget_gate_bias =
GetInput(context, node, kBwForgetGateBiasTensor);
- TfLiteTensor* bw_cell_bias = GetInput(context, node, kBwCellGateBiasTensor);
- TfLiteTensor* bw_output_gate_bias =
+ const TfLiteTensor* bw_cell_bias =
+ GetInput(context, node, kBwCellGateBiasTensor);
+ const TfLiteTensor* bw_output_gate_bias =
GetInput(context, node, kBwOutputGateBiasTensor);
TfLiteTensor* bw_projection_weights =
diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc
index 17ef2c572e..673eedc2e9 100644
--- a/tensorflow/contrib/lite/kernels/cast.cc
+++ b/tensorflow/contrib/lite/kernels/cast.cc
@@ -32,7 +32,7 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// TODO(ahentz): these two checks would make the new implementation
@@ -77,7 +77,7 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const int num_elements = NumElements(input);
TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output));
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 2885ce032b..b948334b6d 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -32,8 +32,8 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Don't support string and bool.
@@ -68,8 +68,8 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<bool>(output), GetTensorDims(output));
TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool requires_broadcast = !HaveSameShapes(input1, input2);
// TODO(renjieliu): Support quantized data.
@@ -92,8 +92,8 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool requires_broadcast = !HaveSameShapes(input1, input2);
// TODO(renjieliu): Support quantized data.
@@ -116,8 +116,8 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool requires_broadcast = !HaveSameShapes(input1, input2);
// TODO(renjieliu): Support quantized data.
@@ -140,8 +140,8 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool requires_broadcast = !HaveSameShapes(input1, input2);
// TODO(renjieliu): Support quantized data.
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index eeda1bc3c5..3ad8d7d4e1 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -83,9 +83,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bool hasBias = NumInputs(node) == 3;
TF_LITE_ENSURE(context, hasBias || NumInputs(node) == 2);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
- TfLiteTensor* bias = nullptr;
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+ const TfLiteTensor* bias = nullptr;
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@@ -169,8 +169,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
void EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteDepthwiseConvParams* params, OpData* data,
- TfLiteTensor* input, TfLiteTensor* filter, TfLiteTensor* bias,
- TfLiteTensor* output) {
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRangeFloat(params->activation, &output_activation_min,
&output_activation_max);
@@ -196,8 +196,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteDepthwiseConvParams* params, OpData* data,
- TfLiteTensor* input, TfLiteTensor* filter,
- TfLiteTensor* bias, TfLiteTensor* output) {
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output) {
auto input_offset = -input->params.zero_point;
auto filter_offset = -filter->params.zero_point;
auto output_offset = output->params.zero_point;
@@ -230,9 +230,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
- TfLiteTensor* bias =
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+ const TfLiteTensor* bias =
(NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
// TODO(aselle): Consider whether float conv and quantized conv should be
diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc
index e685f2465f..672b2170e4 100644
--- a/tensorflow/contrib/lite/kernels/dequantize.cc
+++ b/tensorflow/contrib/lite/kernels/dequantize.cc
@@ -32,7 +32,7 @@ struct OpContext {
input = GetInput(context, node, 0);
output = GetOutput(context, node, 0);
}
- TfLiteTensor* input;
+ const TfLiteTensor* input;
TfLiteTensor* output;
};
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index ec380c8e49..e52e4fe535 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -57,8 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
@@ -80,7 +80,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
void EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteDivParams* params, const OpData* data,
- TfLiteTensor* input1, TfLiteTensor* input2,
+ const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRangeFloat(params->activation, &output_activation_min,
@@ -106,15 +106,13 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
#undef TF_LITE_DIV
}
-
-
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteDivParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
new file mode 100644
index 0000000000..b719a08394
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace elementwise {
+
+TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+ // Quantized float is not supported yet.
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ size_t elements = NumElements(input);
+ const float* in = GetTensorData<float>(input);
+ const float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; in++, out++) *out = std::sin(*in);
+ return kTfLiteOk;
+ }
+ default: {
+ context->ReportError(context, "Only float32 is supported currently");
+ return kTfLiteError;
+ }
+ }
+}
+
+} // namespace elementwise
+
+TfLiteRegistration* Register_SIN() {
+ static TfLiteRegistration r = {nullptr, nullptr, elementwise::SinPrepare,
+ elementwise::SinEval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
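The new elementwise.cc above is the minimal template for a float32 unary op: Prepare checks arity and type and copies the input shape, and Eval walks the float buffer. Purely as an illustration (nothing below is part of this patch), a hypothetical cosine kernel would reuse the same structure:

    // Hypothetical sketch; reuses elementwise::SinPrepare unchanged.
    TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
      const TfLiteTensor* input = GetInput(context, node, 0);
      TfLiteTensor* output = GetOutput(context, node, 0);
      if (input->type != kTfLiteFloat32) {
        context->ReportError(context, "Only float32 is supported currently");
        return kTfLiteError;
      }
      const size_t elements = NumElements(input);
      const float* in = GetTensorData<float>(input);
      float* out = output->data.f;
      for (size_t i = 0; i < elements; ++i) out[i] = std::cos(in[i]);
      return kTfLiteOk;
    }

    TfLiteRegistration* Register_COS() {  // hypothetical
      static TfLiteRegistration r = {nullptr, nullptr, elementwise::SinPrepare, CosEval};
      return &r;
    }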
diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc
new file mode 100644
index 0000000000..412ffb04b9
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class SinOpModel : public SingleOpModel {
+ public:
+ SinOpModel(std::initializer_list<int> input_shape) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_SIN, BuiltinOptions_NONE, 0);
+ BuildInterpreter({input_shape});
+ }
+
+ int input() const { return input_; }
+ int output() const { return output_; }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(ElementWise, Sin) {
+ SinOpModel m({1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {0, 3.1415926, -3.1415926, 1});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({0, 0, 0, 0.84147})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index 4e8cb396d4..7539c0b30d 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -51,11 +51,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* lookup = GetInput(context, node, 0);
+ const TfLiteTensor* lookup = GetInput(context, node, 0);
TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
- TfLiteTensor* value = GetInput(context, node, 1);
+ const TfLiteTensor* value = GetInput(context, node, 1);
TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
TfLiteTensor* output = GetOutput(context, node, 0);
@@ -71,8 +71,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
- TfLiteTensor* lookup = GetInput(context, node, 0);
- TfLiteTensor* value = GetInput(context, node, 1);
+ const TfLiteTensor* lookup = GetInput(context, node, 0);
+ const TfLiteTensor* value = GetInput(context, node, 1);
const int row_size = SizeOfDimension(value, 0);
const int row_bytes = value->bytes / row_size;
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
index 6c770e7f71..d3be36993c 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
@@ -81,19 +81,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 5);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* ids = GetInput(context, node, 0);
+ const TfLiteTensor* ids = GetInput(context, node, 0);
TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1);
TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32);
- TfLiteTensor* indices = GetInput(context, node, 1);
+ const TfLiteTensor* indices = GetInput(context, node, 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2);
TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);
- TfLiteTensor* shape = GetInput(context, node, 2);
+ const TfLiteTensor* shape = GetInput(context, node, 2);
TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);
- TfLiteTensor* weights = GetInput(context, node, 3);
+ const TfLiteTensor* weights = GetInput(context, node, 3);
TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1);
TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);
@@ -102,7 +102,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
SizeOfDimension(weights, 0));
- TfLiteTensor* value = GetInput(context, node, 4);
+ const TfLiteTensor* value = GetInput(context, node, 4);
TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
// Mark the output as a dynamic tensor.
@@ -139,11 +139,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteEmbeddingLookupSparseParams*>(node->builtin_data);
TfLiteTensor* output = GetOutput(context, node, 0);
- TfLiteTensor* ids = GetInput(context, node, 0);
- TfLiteTensor* indices = GetInput(context, node, 1);
- TfLiteTensor* dense_shape = GetInput(context, node, 2);
- TfLiteTensor* weights = GetInput(context, node, 3);
- TfLiteTensor* value = GetInput(context, node, 4);
+ const TfLiteTensor* ids = GetInput(context, node, 0);
+ const TfLiteTensor* indices = GetInput(context, node, 1);
+ const TfLiteTensor* dense_shape = GetInput(context, node, 2);
+ const TfLiteTensor* weights = GetInput(context, node, 3);
+ const TfLiteTensor* value = GetInput(context, node, 4);
const int lookup_rank = SizeOfDimension(indices, 1);
const int embedding_rank = NumDimensions(value);
diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc
index a9e79b742d..ce03cdfe26 100644
--- a/tensorflow/contrib/lite/kernels/exp.cc
+++ b/tensorflow/contrib/lite/kernels/exp.cc
@@ -36,7 +36,7 @@ struct ExpContext {
input = GetInput(context, node, 0);
output = GetOutput(context, node, 0);
}
- TfLiteTensor* input;
+ const TfLiteTensor* input;
TfLiteTensor* output;
};
diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc
index 4b4395f711..697b777693 100644
--- a/tensorflow/contrib/lite/kernels/floor.cc
+++ b/tensorflow/contrib/lite/kernels/floor.cc
@@ -27,7 +27,7 @@ constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -38,7 +38,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
optimized_ops::Floor(GetTensorData<float>(input), GetTensorDims(input),
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index 470b52b7bc..39b108629a 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -89,8 +89,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@@ -158,8 +158,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data,
- TfLiteTensor* input, TfLiteTensor* filter,
- TfLiteTensor* bias, TfLiteTensor* output) {
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output) {
int total_input_size = 1;
for (int i = 0; i < input->dims->size; i++) {
total_input_size *= input->dims->data[i];
@@ -191,8 +191,10 @@ TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data,
- TfLiteTensor* input, TfLiteTensor* filter,
- TfLiteTensor* bias, TfLiteTensor* input_quantized,
+ const TfLiteTensor* input,
+ const TfLiteTensor* filter,
+ const TfLiteTensor* bias,
+ TfLiteTensor* input_quantized,
TfLiteTensor* output) {
// Check the types for this hybrid Op.
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
@@ -271,8 +273,9 @@ TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data,
- TfLiteTensor* input, TfLiteTensor* filter,
- TfLiteTensor* bias, TfLiteTensor* output) {
+ const TfLiteTensor* input,
+ const TfLiteTensor* filter, const TfLiteTensor* bias,
+ TfLiteTensor* output) {
gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
int32_t input_offset = -input->params.zero_point;
@@ -311,8 +314,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data,
- TfLiteTensor* input, TfLiteTensor* filter,
- TfLiteTensor* bias, TfLiteTensor* output) {
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRangeFloat(params->activation, &output_activation_min,
&output_activation_max);
@@ -342,8 +345,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index 0e4187d1ea..c452d3ebac 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -35,8 +35,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const auto* params =
reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* positions = GetInput(context, node, kInputPositions);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Only INT32 positions are supported.
TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32);
@@ -81,8 +81,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* positions = GetInput(context, node, kInputPositions);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const int input_rank = NumDimensions(input);
#define TF_LITE_GATHER(data_type, index_type) \
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
index 3b82601d11..41211d41aa 100644
--- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
@@ -60,15 +60,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
- TfLiteTensor* lookup = GetInput(context, node, 0);
+ const TfLiteTensor* lookup = GetInput(context, node, 0);
TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
- TfLiteTensor* key = GetInput(context, node, 1);
+ const TfLiteTensor* key = GetInput(context, node, 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1);
TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32);
- TfLiteTensor* value = GetInput(context, node, 2);
+ const TfLiteTensor* value = GetInput(context, node, 2);
TF_LITE_ENSURE(context, NumDimensions(value) >= 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0),
SizeOfDimension(value, 0));
@@ -102,9 +102,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* hits = GetOutput(context, node, 1);
- TfLiteTensor* lookup = GetInput(context, node, 0);
- TfLiteTensor* key = GetInput(context, node, 1);
- TfLiteTensor* value = GetInput(context, node, 2);
+ const TfLiteTensor* lookup = GetInput(context, node, 0);
+ const TfLiteTensor* key = GetInput(context, node, 1);
+ const TfLiteTensor* value = GetInput(context, node, 2);
const int num_rows = SizeOfDimension(value, 0);
const int row_bytes = value->bytes / num_rows;
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index f142374269..5f9cfc450d 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+
+#include <algorithm>
+
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
namespace tflite {
@@ -40,6 +44,76 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
hidden_state_ptr_batch);
}
+void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
+ float input_weights_scale,
+ const int8_t* recurrent_weights_ptr,
+ float recurrent_weights_scale, const float* bias_ptr,
+ int input_size, int num_units, int batch_size,
+ TfLiteFusedActivation activation,
+ int8_t* quantized_input_ptr_batch,
+ int8_t* quantized_hidden_state_ptr_batch,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
+ // Output = bias
+ tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
+ output_ptr_batch);
+
+  // TODO(mirkov): replace std::minmax_element with a vectorized call.
+ auto minmax_element = std::minmax_element(
+ input_ptr_batch, input_ptr_batch + batch_size * input_size);
+
+  // Skip the quantization and matmul work when the input is all zeros.
+ if (!(*minmax_element.first == 0.0 && *minmax_element.second == 0.0)) {
+ // Quantize input from float to uint8 + quantization params (scaling
+ // factor).
+ float unused_min, unused_max;
+ float* scaling_factors = new float[batch_size];
+ for (int b = 0; b < batch_size; ++b) {
+ const int offset = b * input_size;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, input_size,
+ quantized_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ scaling_factors[b] *= input_weights_scale;
+ }
+
+ // Output += input * input_weights
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_weights_ptr, num_units, input_size, quantized_input_ptr_batch,
+ scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1);
+ delete[] scaling_factors;
+ }
+
+ minmax_element = std::minmax_element(
+ hidden_state_ptr_batch, hidden_state_ptr_batch + batch_size * num_units);
+  // Skip the quantization and matmul work when the hidden state is all zeros.
+ if (!(*minmax_element.first == 0.0 && *minmax_element.second == 0.0)) {
+ // Quantize hidden_state
+ float unused_min, unused_max;
+ float* scaling_factors = new float[batch_size];
+ for (int b = 0; b < batch_size; ++b) {
+ const int offset = b * num_units;
+ tensor_utils::SymmetricQuantizeFloats(
+ hidden_state_ptr_batch + offset, num_units,
+ quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ scaling_factors[b] *= recurrent_weights_scale;
+ }
+
+ // Output += recurrent_weights * hidden_state
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_weights_ptr, num_units, num_units,
+ quantized_hidden_state_ptr_batch, scaling_factors, batch_size,
+ output_ptr_batch, /*result_stride=*/1);
+ delete[] scaling_factors;
+ }
+
+ // Output = activation(Output) and update hidden_state
+ tensor_utils::ApplyActivationToVector(
+ output_ptr_batch, num_units * batch_size, activation, output_ptr_batch);
+ tensor_utils::VectorBatchVectorAssign(output_ptr_batch, num_units, batch_size,
+ hidden_state_ptr_batch);
+}
+
void LstmStep(
const float* input_ptr_batch, const float* input_to_input_weights_ptr,
const float* input_to_forget_weights_ptr,
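The quantized RnnBatchStep added above follows the usual hybrid scheme: each batch row of the input (and of the hidden state) is symmetrically quantized to int8 with its own scaling factor, the int8 matrix-vector products are accumulated, and multiplying by scaling_factor * weights_scale recovers float output. A self-contained sketch of that per-row step, written without the tensor_utils helpers and for illustration only:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Symmetric quantization of one row into [-127, 127]; returns the scale such that
    // float ~= int8 * scale. All-zero rows are skipped by the caller, as in the kernel.
    float SymmetricQuantizeRow(const float* in, int size, int8_t* out) {
      float max_abs = 0.f;
      for (int i = 0; i < size; ++i) max_abs = std::max(max_abs, std::fabs(in[i]));
      if (max_abs == 0.f) return 0.f;
      const float scale = max_abs / 127.f;
      for (int i = 0; i < size; ++i) {
        out[i] = static_cast<int8_t>(std::round(in[i] / scale));
      }
      return scale;
    }

    // One output unit of "output += weights_row . quantized_input", rescaled by
    // (row scale * weights scale); roughly what MatrixBatchVectorMultiplyAccumulate
    // does for every unit of every batch in the kernel above.
    void AccumulateUnit(const int8_t* weights_row, const int8_t* q_input,
                        int input_size, float combined_scale, float* output) {
      int32_t acc = 0;
      for (int i = 0; i < input_size; ++i) acc += weights_row[i] * q_input[i];
      *output += acc * combined_scale;
    }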
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index 3ec60ee57a..cbfbcbeefc 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -35,6 +35,23 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
TfLiteFusedActivation activation,
float* hidden_state_ptr_batch, float* output_ptr_batch);
+// Performs a quantized RNN batch inference step. Same as above, but for
+// quantization purposes, we also pass in quantized_hidden_state_ptr_batch and
+// quantized_input_ptr_batch pointers for temporary storage of the quantized
+// values of hidden_state_ptr_batch and input_ptr_batch, respectively.
+// These temporary buffers are expected to be preallocated to the same sizes as
+// the float buffers they correspond to.
+// {input,recurrent}_weights_scale params are used for dequantization/recovery.
+void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
+ float input_weights_scale,
+ const int8_t* recurrent_weights_ptr,
+ float recurrent_weights_scale, const float* bias_ptr,
+ int input_size, int num_units, int batch_size,
+ TfLiteFusedActivation activation,
+ int8_t* quantized_input_ptr_batch,
+ int8_t* quantized_hidden_state_ptr_batch,
+ float* hidden_state_ptr_batch, float* output_ptr_batch);
+
// Performs an LSTM batch inference step for input specified by input_ptr_batch.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 65f25168e3..08f7cfa5a5 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -56,9 +56,12 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1));
// The arrays used to cache the vector.
+ void* aligned_vector_cache_free = nullptr;
float32x4_t* vector_cache_float32x4 =
- new float32x4_t[(m_cols / kFloatWeightsPerNeonLane) *
- sizeof(float32x4_t)];
+ reinterpret_cast<float32x4_t*>(aligned_alloc(
+ sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t),
+ &aligned_vector_cache_free));
+
const int kUnrollSize = 2;
for (int b = 0; b < n_batch; b++) {
float* result_in_batch = result + b * m_rows * result_stride;
@@ -71,7 +74,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
matrix_ptr1 = matrix + m_cols;
}
- // Cahce the vector.
+ // Cache the vector.
for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c);
}
@@ -128,7 +131,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
result_in_batch += result_stride;
}
}
- delete[] vector_cache_float32x4;
+ free(aligned_vector_cache_free);
}
void NeonMatrixBatchVectorMultiplyAccumulate(
@@ -294,9 +297,12 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
// The arrays used to cache the vector.
+ void* aligned_vector_cache_free = nullptr;
float32x4_t* vector_cache_float32x4 =
- new float32x4_t[(v_size / kFloatWeightsPerNeonLane) *
- sizeof(float32x4_t)];
+ reinterpret_cast<float32x4_t*>(aligned_alloc(
+ sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t),
+ &aligned_vector_cache_free));
+
for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v);
}
@@ -322,7 +328,7 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
result_ptr += v_size;
batch_vector_ptr += v_size;
}
- delete[] vector_cache_float32x4;
+ free(aligned_vector_cache_free);
}
void NeonSub1Vector(const float* vector, int v_size, float* result) {
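For context on the allocation change above: postamble_start is the column (or vector) count rounded down to a multiple of kFloatWeightsPerNeonLane (4 floats), so the cache needs

    // bytes needed = (postamble_start >> 2)   // number of cached 4-float lanes
    //              * sizeof(float32x4_t);     // 16 bytes per lane

The old expression new float32x4_t[(m_cols / kFloatWeightsPerNeonLane) * sizeof(float32x4_t)] allocated that many elements rather than bytes, roughly 16x more memory than needed, and left alignment to operator new; the aligned_alloc helper sizes the buffer correctly and makes the 16-byte alignment required for float32x4_t stores explicit.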
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 580d208beb..25776244ba 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -48,6 +48,8 @@ using reference_ops::Greater;
using reference_ops::GreaterEqual;
using reference_ops::Less;
using reference_ops::LessEqual;
+using reference_ops::RankOneSelect;
+using reference_ops::Select;
// Make a local VectorMap typedef allowing to map a float array
// as a Eigen vector expression. The std::conditional here is to
@@ -2499,52 +2501,17 @@ inline void Add(const float* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-template <FusedActivationFunctionType Ac>
-inline void Add(int left_shift, const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const uint8* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier, int input2_shift,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- gemmlowp::ScopedProfilingLabel label("Add/8bit");
- /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
- output_dims, 3);
- /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
- output_dims, 2);
- /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
- output_dims, 1);
- /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
- output_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-
+// Element-wise add that can often be used for the inner loop of a broadcast add
+// as well as for the non-broadcast add.
+inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
+ int32 input1_offset, int32 input1_multiplier,
+ int input1_shift, const uint8* input2_data,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data) {
int i = 0;
- const int size = input1_dims.sizes[3] * input1_dims.strides[3];
TFLITE_DCHECK_GT(input1_offset, -256);
TFLITE_DCHECK_GT(input2_offset, -256);
TFLITE_DCHECK_LT(input1_offset, 256);
@@ -2623,6 +2590,54 @@ inline void Add(int left_shift, const uint8* input1_data,
}
}
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ gemmlowp::ScopedProfilingLabel label("Add/8bit");
+ const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+
+ TFLITE_DCHECK_GT(input1_offset, -256);
+ TFLITE_DCHECK_GT(input2_offset, -256);
+ TFLITE_DCHECK_LT(input1_offset, 256);
+ TFLITE_DCHECK_LT(input2_offset, 256);
+ AddElementwise(flat_size, left_shift, input1_data, input1_offset,
+ input1_multiplier, input1_shift, input2_data, input2_offset,
+ input2_multiplier, input2_shift, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data);
+}
+
template <FusedActivationFunctionType Ac>
inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
int input1_shift, const int16* input2_data,
@@ -2833,27 +2848,11 @@ inline void BroadcastAddFivefold(
input2_data_ptr = input2_data_reset;
for (int i2 = 0; i2 < y2; ++i2) {
for (int i1 = 0; i1 < y1; ++i1) {
- for (int i0 = 0; i0 < y0; ++i0) {
- const int32 input1_val = input1_offset + input1_data_ptr[i0];
- const int32 input2_val = input2_offset + input2_data_ptr[i0];
- const int32 shifted_input1_val = input1_val * (1 << left_shift);
- const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
- const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
- const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sum, output_multiplier, output_shift) +
- output_offset;
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, raw_output));
- output_data_ptr[i0] = static_cast<uint8>(clamped_output);
- }
+ AddElementwise(
+ y0, left_shift, input1_data_ptr, input1_offset, input1_multiplier,
+ input1_shift, input2_data_ptr, input2_offset, input2_multiplier,
+ input2_shift, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data_ptr);
input2_data_ptr += y0;
output_data_ptr += y0;
}
@@ -6045,10 +6044,10 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
const int start_h = begin[2];
const int stop_h =
- size[2] == -1 ? input_dims.sizes[2] - start_b : start_b + size[2];
+ size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2];
const int start_w = begin[1];
const int stop_w =
- size[1] == -1 ? input_dims.sizes[1] - start_b : start_b + size[1];
+ size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1];
const int start_d = begin[0];
const int stop_d =
size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
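The two-line fix above (repeated for the reference implementation further down) corrects a copy-paste bug: the height and width stop indices were computed from start_b, the batch start, rather than from their own axes. In short:

    // before: stop_h = size[2] == -1 ? sizes[2] - start_b : start_b + size[2];
    // after:  stop_h = size[2] == -1 ? sizes[2] - start_h : start_h + size[2];
    // (and likewise stop_w now uses start_w), so slices whose begin[3] differs from
    // begin[2] or begin[1] no longer cover the wrong height/width range.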
@@ -6318,59 +6317,6 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
}
-// UNOPTIMIZED COPY of Select from reference_ops.h.
-template <typename D, typename T>
-inline void Select(const D* input_condition_data,
- const Dims<4>& input_condition_dims, const T* input_x_data,
- const Dims<4>& input_x_dims, const T* input_y_data,
- const Dims<4>& input_y_dims, T* output_data,
- const Dims<4>& output_dims) {
- const int64_t batches =
- MatchingArraySize(input_condition_dims, 3, input_x_dims, 3, input_y_dims,
- 3, output_dims, 3);
- const int64_t height =
- MatchingArraySize(input_condition_dims, 2, input_x_dims, 2, input_y_dims,
- 2, output_dims, 2);
- const int64_t width = MatchingArraySize(input_condition_dims, 1, input_x_dims,
- 1, input_y_dims, 1, output_dims, 1);
- const int64_t depth = MatchingArraySize(input_condition_dims, 0, input_x_dims,
- 0, input_y_dims, 0, output_dims, 0);
-
- const int64_t num_elements = batches * height * width * depth;
- for (int64_t i = 0; i < num_elements; ++i) {
- output_data[i] =
- input_condition_data[i] ? input_x_data[i] : input_y_data[i];
- }
-}
-
-// UNOPTIMIZED COPY of RankOneSelect from reference_ops.h.
-template <typename D, typename T>
-inline void RankOneSelect(const D* input_condition_data,
- const Dims<4>& input_condition_dims,
- const T* input_x_data, const Dims<4>& input_x_dims,
- const T* input_y_data, const Dims<4>& input_y_dims,
- T* output_data, const Dims<4>& output_dims) {
- const int64_t rank = ArraySize(input_condition_dims, 0);
-
- const int64_t batches =
- MatchingArraySize(input_x_dims, 3, input_y_dims, 3, output_dims, 3);
- const int64_t height =
- MatchingArraySize(input_x_dims, 2, input_y_dims, 2, output_dims, 2);
- const int64_t width =
- MatchingArraySize(input_x_dims, 1, input_y_dims, 1, output_dims, 1);
- const int64_t depth =
- MatchingArraySize(input_x_dims, 0, input_y_dims, 0, output_dims, 0);
-
- TFLITE_DCHECK_EQ(rank, batches);
-
- int64_t offset = 0;
- int64_t size = depth * height * width;
- for (int64_t i = 0; i < rank; i++) {
- const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
- memcpy(output_data + offset, input_data + offset, size * sizeof(T));
- }
-}
-
} // namespace optimized_ops
} // namespace tflite
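The refactor above pulls the quantized inner loop out of the templated Add into AddElementwise so that both the plain Add wrapper and BroadcastAddFivefold call one copy of the code (the scalar loop deleted from BroadcastAddFivefold earlier in this hunk). Per element, that loop computed, in sketch form:

    // Scalar form of the quantized add, as in the loop removed above (sketch only):
    //   input1_val = input1_offset + input1_data[i];
    //   input2_val = input2_offset + input2_data[i];
    //   shifted1 = input1_val * (1 << left_shift);
    //   shifted2 = input2_val * (1 << left_shift);
    //   scaled1 = MultiplyByQuantizedMultiplierSmallerThanOne(shifted1, input1_multiplier, input1_shift);
    //   scaled2 = MultiplyByQuantizedMultiplierSmallerThanOne(shifted2, input2_multiplier, input2_shift);
    //   raw = MultiplyByQuantizedMultiplierSmallerThanOne(scaled1 + scaled2, output_multiplier, output_shift) + output_offset;
    //   output_data[i] = clamp(raw, output_activation_min, output_activation_max);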
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index e2978cfd67..db0802daed 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -1456,33 +1456,6 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
output_data, output_dims);
}
-inline void Div(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int batches =
- MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
- const int height =
- MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
- const int width =
- MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
- const int depth =
- MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < height; ++y) {
- for (int x = 0; x < width; ++x) {
- for (int c = 0; c < depth; ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- input1_data[Offset(input1_dims, c, x, y, b)] /
- input2_data[Offset(input2_dims, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
- }
-}
-
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
@@ -1524,6 +1497,18 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
+inline void Div(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] / input2_data[i], output_activation_min,
+ output_activation_max);
+ }
+}
+
inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
const float* input2_data, const Dims<4>& input2_dims,
float output_activation_min, float output_activation_max,
@@ -3256,10 +3241,10 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3];
const int start_h = begin[2];
const int stop_h =
- size[2] == -1 ? input_dims.sizes[2] - start_b : start_b + size[2];
+ size[2] == -1 ? input_dims.sizes[2] - start_h : start_h + size[2];
const int start_w = begin[1];
const int stop_w =
- size[1] == -1 ? input_dims.sizes[1] - start_b : start_b + size[1];
+ size[1] == -1 ? input_dims.sizes[1] - start_w : start_w + size[1];
const int start_d = begin[0];
const int stop_d =
size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0];
@@ -3285,11 +3270,11 @@ inline void Exp(const T* input_data, const size_t num_elements,
}
template <typename T, typename U>
-inline bool Mean(T* input_data, const int* input_dims, const int input_num_dims,
- T* output_data, const int* output_dims,
- const int output_num_dims, const int* axis,
- const int num_axis_dimensions, bool keep_dims, int* temp_index,
- int* resolved_axis, U* temp_sum) {
+inline bool Mean(const T* input_data, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int num_axis_dimensions, bool keep_dims,
+ int* temp_index, int* resolved_axis, U* temp_sum) {
// resets output data.
size_t num_outputs = 1;
for (int idx = 0; idx < output_num_dims; ++idx) {
@@ -3621,7 +3606,7 @@ inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
}
}
-template <typename T, ComparisonFn<T> F>
+template <typename T, ComparisonFn<int32> F>
inline void Comparison(int left_shift, const T* input1_data,
const Dims<4>& input1_dims, int32 input1_offset,
int32 input1_multiplier, int input1_shift,
@@ -3672,7 +3657,7 @@ inline void BroadcastComparison(const T* input1_data,
}
}
-template <typename T, ComparisonFn<T> F>
+template <typename T, ComparisonFn<int32> F>
inline void BroadcastComparison(int left_shift, const T* input1_data,
const Dims<4>& input1_dims, int32 input1_offset,
int32 input1_multiplier, int input1_shift,
@@ -3724,11 +3709,11 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
int32 input2_multiplier, int input2_shift, bool* output_data, \
const Dims<4>& output_dims) { \
gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
- BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, \
- input1_shift, input2_data, input2_dims, \
- input2_offset, input2_multiplier, \
- input2_shift, output_data, output_dims); \
+ Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, input1_shift, \
+ input2_data, input2_dims, input2_offset, \
+ input2_multiplier, input2_shift, output_data, \
+ output_dims); \
} \
template <typename T> \
inline void Broadcast##name( \
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index 62cea143e6..ce887cea8b 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -49,6 +49,34 @@ inline bool* GetTensorData(TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.b : nullptr;
}
+template <typename T>
+inline const T* GetTensorData(const TfLiteTensor* tensor);
+
+template <>
+inline const float* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline const bool* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
inline int RemapDim(int max_dimensions, int d) {
return max_dimensions - d - 1;
}
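The const specializations above mirror the existing non-const GetTensorData overloads so that the const TfLiteTensor* inputs introduced throughout this change can be read without a cast. A minimal usage sketch, assuming the kernel_util.h change in this diff:

    const TfLiteTensor* input = GetInput(context, node, 0);  // GetInput now returns const
    const float* in = GetTensorData<float>(input);           // resolves to the const overload
    TfLiteTensor* output = GetOutput(context, node, 0);
    float* out = GetTensorData<float>(output);               // non-const overload, unchanged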
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc
index 955e8c5764..239b533a17 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.cc
+++ b/tensorflow/contrib/lite/kernels/kernel_util.cc
@@ -22,9 +22,12 @@ limitations under the License.
namespace tflite {
-TfLiteStatus GetQuantizedConvolutionMultipler(
- TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter,
- TfLiteTensor* bias, TfLiteTensor* output, double* multiplier) {
+TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
+ const TfLiteTensor* input,
+ const TfLiteTensor* filter,
+ const TfLiteTensor* bias,
+ TfLiteTensor* output,
+ double* multiplier) {
const double input_product_scale = input->params.scale * filter->params.scale;
const double bias_scale = bias->params.scale;
const double output_scale = output->params.scale;
@@ -87,13 +90,13 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
}
}
-bool HaveSameShapes(TfLiteTensor* input1, TfLiteTensor* input2) {
+bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) {
return TfLiteIntArrayEqual(input1->dims, input2->dims);
}
TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
- TfLiteTensor* input1,
- TfLiteTensor* input2,
+ const TfLiteTensor* input1,
+ const TfLiteTensor* input2,
TfLiteIntArray** output_shape) {
int64_t dims1 = NumDimensions(input1);
int64_t dims2 = NumDimensions(input2);
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index e225443a67..de0e368891 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.h
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -24,8 +24,8 @@ inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
return t->dims->data[dim];
}
-inline TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node,
- int index) {
+inline const TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node,
+ int index) {
return &context->tensors[node->inputs->data[index]];
}
inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node,
@@ -78,9 +78,12 @@ inline void SetTensorToDynamic(TfLiteTensor* tensor) {
// Calculates the multiplication factor for a quantized convolution (or
// quantized depthwise convolution) involving the given tensors. Returns an
// error if the scales of the tensors are not compatible.
-TfLiteStatus GetQuantizedConvolutionMultipler(
- TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter,
- TfLiteTensor* bias, TfLiteTensor* output, double* multiplier);
+TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
+ const TfLiteTensor* input,
+ const TfLiteTensor* filter,
+ const TfLiteTensor* bias,
+ TfLiteTensor* output,
+ double* multiplier);
// Calculates the useful range of an activation layer given its activation
// tensor.
@@ -92,13 +95,13 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
float* activation_max);
// Return true if the given tensors have the same shape.
-bool HaveSameShapes(TfLiteTensor* input1, TfLiteTensor* input2);
+bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2);
// Calculate the output_shape that is necessary for element-wise operations
// with broadcasting involving the two input tensors.
TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
- TfLiteTensor* input1,
- TfLiteTensor* input2,
+ const TfLiteTensor* input1,
+ const TfLiteTensor* input2,
TfLiteIntArray** output_shape);
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/kernel_util_test.cc b/tensorflow/contrib/lite/kernels/kernel_util_test.cc
index c65b68970f..bf6f249acc 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/kernel_util_test.cc
@@ -33,7 +33,7 @@ class KernelUtilTest : public ::testing::Test {
tensor1_.allocation_type = kTfLiteMmapRo;
tensor2_.allocation_type = kTfLiteMmapRo;
}
- ~KernelUtilTest() {
+ ~KernelUtilTest() override {
TfLiteTensorFree(&tensor1_);
TfLiteTensorFree(&tensor2_);
}
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index e67f4e06f3..7cea63da87 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -40,7 +40,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
@@ -64,7 +64,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc
index c1c70d0dfa..c15a5170b8 100644
--- a/tensorflow/contrib/lite/kernels/local_response_norm.cc
+++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc
@@ -38,7 +38,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@@ -60,7 +60,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteLocalResponseNormParams*>(node->builtin_data);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc
index 0ee35775d5..25d2dc2cdd 100644
--- a/tensorflow/contrib/lite/kernels/lsh_projection.cc
+++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc
@@ -77,16 +77,16 @@ TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* hash = GetInput(context, node, 0);
+ const TfLiteTensor* hash = GetInput(context, node, 0);
TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2);
// Support up to 32 bits.
TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32);
- TfLiteTensor* input = GetInput(context, node, 1);
+ const TfLiteTensor* input = GetInput(context, node, 1);
TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
if (NumInputs(node) == 3) {
- TfLiteTensor* weight = GetInput(context, node, 2);
+ const TfLiteTensor* weight = GetInput(context, node, 2);
TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0),
SizeOfDimension(input, 0));
@@ -173,9 +173,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
reinterpret_cast<TfLiteLSHProjectionParams*>(node->builtin_data);
int32_t* out_buf = GetOutput(context, node, 0)->data.i32;
- TfLiteTensor* hash = GetInput(context, node, 0);
- TfLiteTensor* input = GetInput(context, node, 1);
- TfLiteTensor* weight =
+ const TfLiteTensor* hash = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 1);
+ const TfLiteTensor* weight =
NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2);
switch (params->type) {
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index a1521efbb4..8d447a2dcf 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -100,13 +100,13 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
}
- TfLiteTensor* input_to_forget_weights =
+ const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
- TfLiteTensor* input_to_cell_weights =
+ const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
@@ -122,7 +122,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
n_output);
}
- TfLiteTensor* recurrent_to_forget_weights =
+ const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
@@ -130,7 +130,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
n_output);
- TfLiteTensor* recurrent_to_cell_weights =
+ const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
@@ -188,16 +188,16 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
}
- TfLiteTensor* forget_gate_bias =
+ const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
- TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
- TfLiteTensor* output_gate_bias =
+ const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
@@ -241,18 +241,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input->dims->size > 1);
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];
- TfLiteTensor* input_to_output_weights =
+ const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const int n_cell = input_to_output_weights->dims->data[0];
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
- TfLiteTensor* recurrent_to_output_weights =
+ const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
@@ -322,24 +322,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- TfLiteTensor* input_to_forget_weights =
+ const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
- TfLiteTensor* input_to_cell_weights =
+ const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
- TfLiteTensor* input_to_output_weights =
+ const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- TfLiteTensor* recurrent_to_forget_weights =
+ const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
- TfLiteTensor* recurrent_to_cell_weights =
+ const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
- TfLiteTensor* recurrent_to_output_weights =
+ const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
TfLiteTensor* cell_to_input_weights =
@@ -351,10 +351,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
- TfLiteTensor* forget_gate_bias =
+ const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
- TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
- TfLiteTensor* output_gate_bias =
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
TfLiteTensor* projection_weights =
diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
index 5a28d663c9..8d676218bd 100644
--- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
@@ -41,8 +41,8 @@ struct OpContext {
input2 = GetInput(context, node, kInputTensor2);
output = GetOutput(context, node, kOutputTensor);
}
- TfLiteTensor* input1;
- TfLiteTensor* input2;
+ const TfLiteTensor* input1;
+ const TfLiteTensor* input2;
TfLiteTensor* output;
};
diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/mean.cc
index 98f80e32d9..03e5db24de 100644
--- a/tensorflow/contrib/lite/kernels/mean.cc
+++ b/tensorflow/contrib/lite/kernels/mean.cc
@@ -40,8 +40,8 @@ struct MeanContext {
output = GetOutput(context, node, 0);
}
TfLiteMeanParams* params;
- TfLiteTensor* input;
- TfLiteTensor* axis;
+ const TfLiteTensor* input;
+ const TfLiteTensor* axis;
TfLiteTensor* output;
};
diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc
index 018db0dc54..3f5bc4d68a 100644
--- a/tensorflow/contrib/lite/kernels/mfcc.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc.cc
@@ -67,8 +67,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* inputWav = GetInput(context, node, kInputTensorWav);
- TfLiteTensor* inputRate = GetInput(context, node, kInputTensorRate);
+ const TfLiteTensor* inputWav = GetInput(context, node, kInputTensorWav);
+ const TfLiteTensor* inputRate = GetInput(context, node, kInputTensorRate);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(inputWav), 3);
@@ -94,8 +94,8 @@ template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteMfccParams*>(node->user_data);
- TfLiteTensor* inputWav = GetInput(context, node, kInputTensorWav);
- TfLiteTensor* inputRate = GetInput(context, node, kInputTensorRate);
+ const TfLiteTensor* inputWav = GetInput(context, node, kInputTensorWav);
+ const TfLiteTensor* inputRate = GetInput(context, node, kInputTensorRate);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const int32 sample_rate = *GetTensorData<int>(inputRate);
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 54575019de..6c4c3a1edc 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -57,8 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
@@ -80,7 +80,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
void EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteMulParams* params, const OpData* data,
- TfLiteTensor* input1, TfLiteTensor* input2,
+ const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRangeFloat(params->activation, &output_activation_min,
@@ -109,7 +109,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteMulParams* params, const OpData* data,
- TfLiteTensor* input1, TfLiteTensor* input2,
+ const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output) {
auto input1_offset = -input1->params.zero_point;
auto input2_offset = -input2->params.zero_point;
@@ -149,8 +149,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc
index 692da81727..b8b53f3402 100644
--- a/tensorflow/contrib/lite/kernels/neg.cc
+++ b/tensorflow/contrib/lite/kernels/neg.cc
@@ -27,7 +27,7 @@ constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
output->type = input->type;
@@ -44,7 +44,7 @@ void Negate(const T* in_data, int num_elements, T* out_data) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const int num_elements = NumElements(input);
switch (input->type) {
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 9e1e4658e9..b1eb6f76a4 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -46,8 +46,8 @@ struct PadContext {
dims = NumDimensions(input);
}
TfLiteTensor* constant_values;
- TfLiteTensor* input;
- TfLiteTensor* paddings;
+ const TfLiteTensor* input;
+ const TfLiteTensor* paddings;
TfLiteTensor* output;
int dims;
};
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
index 0bf27c34c1..645d9f4008 100644
--- a/tensorflow/contrib/lite/kernels/pooling.cc
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -69,7 +69,7 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* output = GetOutput(context, node, 0);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
@@ -122,7 +122,7 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLitePoolParams* params, OpData* data,
- TfLiteTensor* input, TfLiteTensor* output) {
+ const TfLiteTensor* input, TfLiteTensor* output) {
float activation_min, activation_max;
CalculateActivationRangeFloat(params->activation, &activation_min,
&activation_max);
@@ -143,7 +143,7 @@ void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLitePoolParams* params, OpData* data,
- TfLiteTensor* input, TfLiteTensor* output) {
+ const TfLiteTensor* input, TfLiteTensor* output) {
int32_t activation_min;
int32_t activation_max;
CalculateActivationRangeUint8(params->activation, output, &activation_min,
@@ -165,8 +165,8 @@ void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLitePoolParams* params, OpData* data, TfLiteTensor* input,
- TfLiteTensor* output) {
+ TfLitePoolParams* params, OpData* data,
+ const TfLiteTensor* input, TfLiteTensor* output) {
float activation_min, activation_max;
CalculateActivationRangeFloat(params->activation, &activation_min,
&activation_max);
@@ -187,7 +187,7 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLitePoolParams* params, OpData* data,
- TfLiteTensor* input, TfLiteTensor* output) {
+ const TfLiteTensor* input, TfLiteTensor* output) {
int32_t activation_min;
int32_t activation_max;
CalculateActivationRangeUint8(params->activation, output, &activation_min,
@@ -209,8 +209,8 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
void L2EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLitePoolParams* params, OpData* data, TfLiteTensor* input,
- TfLiteTensor* output) {
+ TfLitePoolParams* params, OpData* data,
+ const TfLiteTensor* input, TfLiteTensor* output) {
float activation_min, activation_max;
CalculateActivationRangeFloat(params->activation, &activation_min,
&activation_max);
@@ -236,7 +236,7 @@ TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output = GetOutput(context, node, 0);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
AverageEvalFloat<kernel_type>(context, node, params, data, input, output);
@@ -258,7 +258,7 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output = GetOutput(context, node, 0);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
MaxEvalFloat<kernel_type>(context, node, params, data, input, output);
@@ -279,7 +279,7 @@ TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output = GetOutput(context, node, 0);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
L2EvalFloat<kernel_type>(context, node, params, data, input, output);
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 5df35aac62..0c7cfcaf10 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -87,6 +87,8 @@ TfLiteRegistration* Register_LESS_EQUAL();
TfLiteRegistration* Register_FLOOR();
TfLiteRegistration* Register_NEG();
TfLiteRegistration* Register_SELECT();
+TfLiteRegistration* Register_SLICE();
+TfLiteRegistration* Register_SIN();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -155,6 +157,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
AddBuiltin(BuiltinOperator_NEG, Register_NEG());
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
+ AddBuiltin(BuiltinOperator_SLICE, Register_SLICE());
+ AddBuiltin(BuiltinOperator_SIN, Register_SIN());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
@@ -163,29 +167,6 @@ BuiltinOpResolver::BuiltinOpResolver() {
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
}
-TfLiteRegistration* BuiltinOpResolver::FindOp(
- tflite::BuiltinOperator op) const {
- auto it = builtins_.find(op);
- return it != builtins_.end() ? it->second : nullptr;
-}
-
-TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op) const {
- auto it = custom_ops_.find(op);
- return it != custom_ops_.end() ? it->second : nullptr;
-}
-
-void BuiltinOpResolver::AddBuiltin(tflite::BuiltinOperator op,
- TfLiteRegistration* registration) {
- registration->builtin_code = op;
- builtins_.insert(std::make_pair(op, registration));
-}
-
-void BuiltinOpResolver::AddCustom(const char* name,
- TfLiteRegistration* registration) {
- registration->builtin_code = BuiltinOperator_CUSTOM;
- custom_ops_.insert(std::make_pair(std::string(name), registration));
-}
-
} // namespace builtin
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
index b9cff0ae21..b928f1b302 100644
--- a/tensorflow/contrib/lite/kernels/register.h
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -23,24 +23,9 @@ namespace tflite {
namespace ops {
namespace builtin {
-class BuiltinOpResolver : public OpResolver {
+class BuiltinOpResolver : public MutableOpResolver {
public:
BuiltinOpResolver();
- TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override;
- TfLiteRegistration* FindOp(const char* op) const override;
- void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration);
- void AddCustom(const char* name, TfLiteRegistration* registration);
-
- private:
- struct BuiltinOperatorHasher {
- size_t operator()(const tflite::BuiltinOperator& x) const {
- return std::hash<size_t>()(static_cast<size_t>(x));
- }
- };
- std::unordered_map<tflite::BuiltinOperator, TfLiteRegistration*,
- BuiltinOperatorHasher>
- builtins_;
- std::unordered_map<std::string, TfLiteRegistration*> custom_ops_;
};
} // namespace builtin
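
With BuiltinOpResolver now deriving from MutableOpResolver, the FindOp/AddBuiltin/AddCustom bodies deleted from register.cc above live in the shared base class. A minimal sketch of how callers would keep registering and resolving ops under that assumption (the versioned FindOp signature is taken from the test_util.h change further down; Register_MY_CUSTOM_OP is a hypothetical custom-op registration function):

    tflite::ops::builtin::BuiltinOpResolver resolver;
    // Custom ops still go through the inherited MutableOpResolver::AddCustom.
    resolver.AddCustom("MyCustomOp", Register_MY_CUSTOM_OP());
    // Built-in lookups resolve through the inherited FindOp.
    const TfLiteRegistration* slice_reg =
        resolver.FindOp(tflite::BuiltinOperator_SLICE, /*version=*/1);
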
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
index 438f70d311..3287040695 100644
--- a/tensorflow/contrib/lite/kernels/reshape.cc
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -35,7 +35,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Tensorflow's Reshape allows one of the shape components to have the
@@ -70,7 +70,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
memcpy(output->data.raw, input->data.raw, input->bytes);
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index 9e3e19c09a..e4bd0f5b85 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -36,8 +36,10 @@ constexpr int kInputTensor = 0;
constexpr int kSizeTensor = 1;
constexpr int kOutputTensor = 0;
-TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TfLiteTensor* input,
- TfLiteTensor* size, TfLiteTensor* output) {
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ const TfLiteTensor* input,
+ const TfLiteTensor* size,
+ TfLiteTensor* output) {
TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
output_size->data[0] = input->dims->data[0];
const int32* size_data = GetTensorData<int32>(size);
@@ -51,8 +53,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* size = GetInput(context, node, kSizeTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// TODO(ahentz): Our current implementations rely on the inputs being 4D.
@@ -78,9 +80,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* size = GetInput(context, node, kSizeTensor);
+ const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 029ad9a709..9bc8a1a34a 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -33,10 +33,10 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input_condition =
+ const TfLiteTensor* input_condition =
GetInput(context, node, kInputTensorCondition);
- TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
- TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
+ const TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
+ const TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Input must be bool.
@@ -62,10 +62,10 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input_condition =
+ const TfLiteTensor* input_condition =
GetInput(context, node, kInputTensorCondition);
- TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
- TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
+ const TfLiteTensor* input_x = GetInput(context, node, kInputTensorX);
+ const TfLiteTensor* input_y = GetInput(context, node, kInputTensorY);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool is_rank_one = !HaveSameShapes(input_condition, input_x);
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc
new file mode 100644
index 0000000000..b28934e2f7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/slice.cc
@@ -0,0 +1,199 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string.h>
+#include <cmath>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace slice {
+
+constexpr int kInputTensor = 0;
+constexpr int kBeginTensor = 1;
+constexpr int kSizeTensor = 2;
+constexpr int kOutputTensor = 0;
+
+// This op only supports 1D-4D cases; because it reuses the optimized 4D
+// implementation, 1D-3D tensors are mapped to 4D.
+const int kMaxDim = 4;
+
+template <typename T>
+TfLiteStatus CalculateOutputShapeVector(
+ TfLiteContext* context, const TfLiteTensor* input,
+ const TfLiteTensor* begin, const TfLiteTensor* size,
+ std::vector<int64_t>* output_shape_vector) {
+ for (int idx = 0; idx < NumDimensions(input); ++idx) {
+ T size_value = GetTensorData<T>(size)[idx];
+ if (size_value < 0) {
+ if (size_value != -1) {
+ context->ReportError(context, "Invalid size.");
+ return kTfLiteError;
+ }
+ size_value = SizeOfDimension(input, idx) - GetTensorData<T>(begin)[idx];
+ } else {
+ if (SizeOfDimension(input, idx) <
+ GetTensorData<T>(begin)[idx] + size_value) {
+ context->ReportError(context, "Invalid begin and size.");
+ return kTfLiteError;
+ }
+ }
+ output_shape_vector->push_back(size_value);
+ }
+ return kTfLiteOk;
+}
+
+template <typename T>
+void GetBeginAndSizeVectors(int dimensions, const TfLiteTensor* begin,
+ const TfLiteTensor* size, std::vector<int>* begins,
+ std::vector<int>* sizes) {
+ for (int idx = dimensions - 1; idx >= 0; --idx) {
+ begins->push_back(GetTensorData<T>(begin)[idx]);
+ sizes->push_back(GetTensorData<T>(size)[idx]);
+ }
+}
+
+TfLiteStatus ResizeOutputShape(TfLiteContext* context,
+ const TfLiteTensor* input,
+ const TfLiteTensor* begin,
+ const TfLiteTensor* size, TfLiteTensor* output) {
+ std::vector<int64_t> output_shape_vector;
+
+ if (begin->type == kTfLiteInt32) {
+ TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int32_t>(
+ context, input, begin, size, &output_shape_vector));
+ } else if (begin->type == kTfLiteInt64) {
+ TF_LITE_ENSURE_STATUS(CalculateOutputShapeVector<int64_t>(
+ context, input, begin, size, &output_shape_vector));
+ } else {
+ context->ReportError(context, "Type is currently not supported by Slice.");
+ return kTfLiteError;
+ }
+
+ TfLiteIntArray* output_shape =
+ TfLiteIntArrayCreate(output_shape_vector.size());
+ std::copy(output_shape_vector.begin(), output_shape_vector.end(),
+ output_shape->data);
+ return context->ResizeTensor(context, output, output_shape);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* begin = GetInput(context, node, kBeginTensor);
+ const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // Ensure validity of input tensor and its dimension.
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+ TF_LITE_ENSURE(context,
+ begin->type == kTfLiteInt32 || begin->type == kTfLiteInt64);
+ TF_LITE_ENSURE(context,
+ size->type == kTfLiteInt32 || size->type == kTfLiteInt64);
+  TF_LITE_ENSURE(context, NumDimensions(begin) == 1 && NumDimensions(size) == 1);
+ TF_LITE_ENSURE_MSG(context, NumDimensions(input) <= kMaxDim,
+ "Slice op only supports 1D-4D input arrays.");
+
+  // Postpone output allocation if any of the indexing tensors is not
+  // constant.
+ if (!(IsConstantTensor(begin) && IsConstantTensor(size))) {
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+ }
+
+ return ResizeOutputShape(context, input, begin, size, output);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* begin = GetInput(context, node, kBeginTensor);
+ const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeOutputShape(context, input, begin, size, output));
+ }
+
+ std::vector<int> begins;
+ begins.reserve(kMaxDim);
+ std::vector<int> sizes;
+ sizes.reserve(kMaxDim);
+
+ if (begin->type == kTfLiteInt32) {
+ GetBeginAndSizeVectors<int32_t>(NumDimensions(input), begin, size, &begins,
+ &sizes);
+ } else if (begin->type == kTfLiteInt64) {
+ GetBeginAndSizeVectors<int64_t>(NumDimensions(input), begin, size, &begins,
+ &sizes);
+ } else {
+ context->ReportError(context, "Type is currently not supported by Slice.");
+ return kTfLiteError;
+ }
+
+ for (int i = NumDimensions(input); i < kMaxDim; ++i) {
+ begins.push_back(0);
+ sizes.push_back(1);
+ }
+
+#define TF_LITE_SLICE(data_type) \
+ optimized_ops::Slice<data_type>( \
+ GetTensorData<data_type>(input), GetTensorDims(input), begins, sizes, \
+ GetTensorData<data_type>(output), GetTensorDims(output))
+
+ switch (input->type) {
+ case kTfLiteFloat32:
+ TF_LITE_SLICE(float);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_SLICE(int32_t);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_SLICE(int64_t);
+ break;
+ case kTfLiteUInt8:
+ TF_LITE_SLICE(uint8_t);
+ break;
+ case kTfLiteBool:
+ TF_LITE_SLICE(bool);
+ break;
+ default:
+ context->ReportError(context,
+ "Type is currently not supported by Slice.");
+ return kTfLiteError;
+ }
+#undef TF_LITE_SLICE
+ return kTfLiteOk;
+}
+
+} // namespace slice
+
+TfLiteRegistration* Register_SLICE() {
+ static TfLiteRegistration r = {nullptr, nullptr, slice::Prepare, slice::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
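
As a quick reference for the begin/size semantics the new kernel enforces (size == -1 means "to the end of that dimension", and begin + size must stay within the input extent), here is a self-contained sketch mirroring CalculateOutputShapeVector without any TFLite dependencies; the helper name and plain std::vector types are illustrative only:

    #include <cassert>
    #include <vector>

    // Illustrative helper, not part of the kernel: computes the output shape a
    // Slice with the given begin/size vectors would produce.
    std::vector<int> SliceOutputShape(const std::vector<int>& input_dims,
                                      const std::vector<int>& begin,
                                      const std::vector<int>& size) {
      std::vector<int> out;
      for (size_t d = 0; d < input_dims.size(); ++d) {
        // -1 means "take everything from begin[d] to the end of dimension d".
        int extent = size[d] == -1 ? input_dims[d] - begin[d] : size[d];
        assert(begin[d] + extent <= input_dims[d]);  // "Invalid begin and size."
        out.push_back(extent);
      }
      return out;
    }

    // SliceOutputShape({3, 2, 3, 1}, {1, 0, 0, 0}, {2, 1, -1, 1}) yields
    // {2, 1, 3, 1}, matching the SizeMinus1 test below.
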
diff --git a/tensorflow/contrib/lite/kernels/slice_test.cc b/tensorflow/contrib/lite/kernels/slice_test.cc
new file mode 100644
index 0000000000..4828f88f36
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/slice_test.cc
@@ -0,0 +1,173 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename input_type, typename index_type>
+class SliceOpModel : public SingleOpModel {
+ public:
+ SliceOpModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> begin_shape,
+ std::initializer_list<int> size_shape,
+ TensorType tensor_index_type, TensorType tensor_input_type) {
+ input_ = AddInput(tensor_input_type);
+ begin_ = AddInput(tensor_index_type);
+ size_ = AddInput(tensor_index_type);
+ output_ = AddOutput(tensor_input_type);
+ SetBuiltinOp(BuiltinOperator_SLICE, BuiltinOptions_SliceOptions,
+ CreateSliceOptions(builder_).Union());
+ BuildInterpreter({input_shape, begin_shape, size_shape});
+ }
+
+ void SetInput(std::initializer_list<input_type> data) {
+ PopulateTensor<input_type>(input_, data);
+ }
+ void SetBegin(std::initializer_list<index_type> data) {
+ PopulateTensor<index_type>(begin_, data);
+ }
+ void SetSize(std::initializer_list<index_type> data) {
+ PopulateTensor<index_type>(size_, data);
+ }
+
+ std::vector<input_type> GetOutput() {
+ return ExtractVector<input_type>(output_);
+ }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int begin_;
+ int size_;
+ int output_;
+};
+
+TEST(SliceOpTest, In1D) {
+ SliceOpModel<float, int32_t> m({4}, {1}, {1}, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1});
+ m.SetSize({2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
+}
+
+TEST(SliceOpTest, In2D) {
+ SliceOpModel<float, int32_t> m({2, 3}, {2}, {2}, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, 0});
+ m.SetSize({1, 2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5}));
+}
+
+TEST(SliceOpTest, In3D) {
+ SliceOpModel<float, int32_t> m({2, 3, 2}, {3}, {4}, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetSize({2, 3, 2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2}));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+}
+
+TEST(SliceOpTest, InputFloat) {
+ SliceOpModel<float, int32_t> m({4, 1, 1, 1}, {4}, {4}, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1, 0, 0, 0});
+ m.SetSize({3, 1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
+}
+
+TEST(SliceOpTest, IndexInt64) {
+ SliceOpModel<float, int64_t> m({4, 1, 1, 1}, {4}, {4}, TensorType_INT64,
+ TensorType_FLOAT32);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1, 0, 0, 0});
+ m.SetSize({3, 1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 1, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
+}
+
+// See these test cases under:
+// https://www.tensorflow.org/versions/master/api_docs/python/tf/slice
+TEST(SliceOpTest, InputInteger1) {
+ SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
+ TensorType_INT32);
+ m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
+ m.SetBegin({1, 0, 0, 0});
+ m.SetSize({1, 1, 3, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 3, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3}));
+}
+
+TEST(SliceOpTest, InputInteger2) {
+ SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
+ TensorType_INT32);
+ m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
+ m.SetBegin({1, 0, 0, 0});
+ m.SetSize({1, 2, 3, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 4, 4, 4}));
+}
+
+TEST(SliceOpTest, InputInteger3) {
+ SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
+ TensorType_INT32);
+ m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
+ m.SetBegin({1, 0, 0, 0});
+ m.SetSize({2, 1, 3, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
+}
+
+TEST(SliceOpTest, SizeMinus1) {
+ SliceOpModel<int32_t, int32_t> m({3, 2, 3, 1}, {4}, {4}, TensorType_INT32,
+ TensorType_INT32);
+ m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
+ m.SetBegin({1, 0, 0, 0});
+ m.SetSize({2, 1, -1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index d8c9e352f0..1e35869958 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -40,9 +40,9 @@ struct SpaceToBatchNDContext {
paddings = GetInput(context, node, 2);
output = GetOutput(context, node, 0);
}
- TfLiteTensor* input;
- TfLiteTensor* block_shape;
- TfLiteTensor* paddings;
+ const TfLiteTensor* input;
+ const TfLiteTensor* block_shape;
+ const TfLiteTensor* paddings;
TfLiteTensor* output;
};
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc
index cb2e509c98..aafce89512 100644
--- a/tensorflow/contrib/lite/kernels/space_to_depth.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc
@@ -42,7 +42,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@@ -76,7 +76,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteSpaceToDepthParams*>(node->builtin_data);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index b524c79f87..c6b94c25be 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -34,8 +34,8 @@ struct OpContext {
input = GetInput(context, node, 1);
}
TfLiteSplitParams* params;
- TfLiteTensor* axis;
- TfLiteTensor* input;
+ const TfLiteTensor* axis;
+ const TfLiteTensor* input;
};
TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
@@ -46,8 +46,8 @@ TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
- TfLiteTensor* axis, TfLiteTensor* input,
- int num_splits) {
+ const TfLiteTensor* axis,
+ const TfLiteTensor* input, int num_splits) {
int axis_value = GetTensorData<int>(axis)[0];
if (axis_value < 0) {
axis_value += NumDimensions(input);
diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc
index 29447ab021..09a5662fd9 100644
--- a/tensorflow/contrib/lite/kernels/squeeze.cc
+++ b/tensorflow/contrib/lite/kernels/squeeze.cc
@@ -26,13 +26,12 @@ namespace builtin {
namespace squeeze {
struct SqueezeContext {
- SqueezeContext(TfLiteContext* context, TfLiteNode* node) {
- params = reinterpret_cast<TfLiteSqueezeParams*>(node->builtin_data);
- input = GetInput(context, node, 0);
- output = GetOutput(context, node, 0);
- }
+ SqueezeContext(TfLiteContext* context, TfLiteNode* node)
+ : params(reinterpret_cast<TfLiteSqueezeParams*>(node->builtin_data)),
+ input(GetInput(context, node, 0)),
+ output(GetOutput(context, node, 0)) {}
TfLiteSqueezeParams* params;
- TfLiteTensor* input;
+ const TfLiteTensor* const input;
TfLiteTensor* output;
};
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index 40ac436b7d..9417be32b3 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -49,10 +49,10 @@ struct StridedSliceContext {
dims = NumDimensions(input);
}
const TfLiteStridedSliceParams* params;
- TfLiteTensor* input;
- TfLiteTensor* begin;
- TfLiteTensor* end;
- TfLiteTensor* strides;
+ const TfLiteTensor* input;
+ const TfLiteTensor* begin;
+ const TfLiteTensor* end;
+ const TfLiteTensor* strides;
TfLiteTensor* output;
int dims;
};
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 7c60a4fdbf..9531ecba98 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -57,8 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
@@ -80,7 +80,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
void EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteSubParams* params, const OpData* data,
- TfLiteTensor* input1, TfLiteTensor* input2,
+ const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRangeFloat(params->activation, &output_activation_min,
@@ -109,7 +109,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteSubParams* params, const OpData* data,
- TfLiteTensor* input1, TfLiteTensor* input2,
+ const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output) {
auto input1_offset = -input1->params.zero_point;
auto input2_offset = -input2->params.zero_point;
@@ -164,8 +164,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
- TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
- TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 13da51c7a7..788812755e 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -58,9 +58,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* weights_feature =
+ const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
- TfLiteTensor* weights_time = GetInput(context, node, kWeightsTimeTensor);
+ const TfLiteTensor* weights_time =
+ GetInput(context, node, kWeightsTimeTensor);
// Check that all the tensor parameters are consistent with each other and
// with the input configuration.
@@ -123,10 +124,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* weights_feature =
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
- TfLiteTensor* weights_time = GetInput(context, node, kWeightsTimeTensor);
+ const TfLiteTensor* weights_time =
+ GetInput(context, node, kWeightsTimeTensor);
TfLiteTensor* state = GetOutput(context, node, kStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index 5a6c85e97e..1a01ee0936 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -101,7 +101,7 @@ void SingleOpModel::BuildInterpreter(
}
resolver_ = std::unique_ptr<OpResolver>(resolver);
}
- InterpreterBuilder(model, *resolver_)(&interpreter_);
+ CHECK(InterpreterBuilder(model, *resolver_)(&interpreter_) == kTfLiteOk);
CHECK(interpreter_ != nullptr);
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index 6a9fdf1112..55edc97d19 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -89,18 +89,24 @@ struct TensorData {
class SingleOpResolver : public OpResolver {
public:
SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration)
- : op_(op), registration_(registration) {}
- TfLiteRegistration* FindOp(BuiltinOperator op) const override {
+ : op_(op), registration_(*registration) {
+ registration_.builtin_code = static_cast<int32_t>(op);
+ registration_.version = 1;
+ }
+ const TfLiteRegistration* FindOp(BuiltinOperator op,
+ int version) const override {
if (op == op_) {
- return registration_;
+ return &registration_;
}
return nullptr;
}
- TfLiteRegistration* FindOp(const char* op) const override { return nullptr; }
+ const TfLiteRegistration* FindOp(const char* op, int version) const override {
+ return nullptr;
+ }
private:
const BuiltinOperator op_;
- TfLiteRegistration* registration_;
+ TfLiteRegistration registration_;
};
class SingleOpModel {
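
The updated SingleOpResolver copies the registration it is given, stamps builtin_code and version onto the copy, and resolves lookups against an explicit version. A small hedged sketch of the resulting lookup path, assuming the class is used directly the way SingleOpModel's resolver plumbing uses it:

    // Versioned lookup introduced above; only the single registered op resolves.
    TfLiteRegistration* slice = tflite::ops::builtin::Register_SLICE();
    tflite::SingleOpResolver resolver(tflite::BuiltinOperator_SLICE, slice);
    const TfLiteRegistration* found =
        resolver.FindOp(tflite::BuiltinOperator_SLICE, /*version=*/1);
    // found->builtin_code was set in the constructor; asking for any other
    // builtin (or any custom op name) returns nullptr.
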
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc
index ad9b744f1a..b331fc8482 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2.cc
@@ -30,7 +30,7 @@ constexpr int kOutputIndexes = 1;
namespace {
TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
+ const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
// INT32 number of top results is supported.
TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32);
// Check that the tensor contains only one value.
@@ -38,7 +38,7 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
const int32 k = top_k->data.i32[0];
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const int num_dimensions = NumDimensions(input);
// Check that input has one or more dimensions.
TF_LITE_ENSURE_MSG(context, input->dims->size >= 1,
@@ -162,11 +162,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
TF_LITE_ENSURE_EQ(context, input->type, output_values->type);
- TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
+ const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32);
// Set output dynamic if the input is not const.
@@ -187,11 +187,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (IsDynamicTensor(output_values)) {
TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
}
- TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
+ const TfLiteTensor* top_k = GetInput(context, node, kInputTopK);
const int32 k = top_k->data.i32[0];
// The tensor can have more than 2 dimensions or even be a vector; either way,
// the code treats the innermost dimension as the row.
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const int32 row_size = input->dims->data[input->dims->size - 1];
int32 num_rows = 1;
for (int i = 0; i < input->dims->size - 1; ++i) {
diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc
index d3c10a9bb7..8316a23c18 100644
--- a/tensorflow/contrib/lite/kernels/transpose.cc
+++ b/tensorflow/contrib/lite/kernels/transpose.cc
@@ -37,8 +37,8 @@ struct TransposeContext {
perm = GetInput(context, node, 1);
output = GetOutput(context, node, 0);
}
- TfLiteTensor* input;
- TfLiteTensor* perm;
+ const TfLiteTensor* input;
+ const TfLiteTensor* perm;
TfLiteTensor* output;
};
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 5987bf68b5..46d65ca8f8 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -100,13 +100,13 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
}
- TfLiteTensor* input_to_forget_weights =
+ const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
- TfLiteTensor* input_to_cell_weights =
+ const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
@@ -122,7 +122,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
n_output);
}
- TfLiteTensor* recurrent_to_forget_weights =
+ const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
@@ -130,7 +130,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
n_output);
- TfLiteTensor* recurrent_to_cell_weights =
+ const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
@@ -188,16 +188,16 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
}
- TfLiteTensor* forget_gate_bias =
+ const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
- TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
- TfLiteTensor* output_gate_bias =
+ const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
@@ -241,19 +241,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input->dims->size > 1);
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
- TfLiteTensor* input_to_output_weights =
+ const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
const int n_cell = input_to_output_weights->dims->data[0];
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
- TfLiteTensor* recurrent_to_output_weights =
+ const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
@@ -324,24 +324,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
- TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- TfLiteTensor* input_to_forget_weights =
+ const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
- TfLiteTensor* input_to_cell_weights =
+ const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
- TfLiteTensor* input_to_output_weights =
+ const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- TfLiteTensor* recurrent_to_forget_weights =
+ const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
- TfLiteTensor* recurrent_to_cell_weights =
+ const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
- TfLiteTensor* recurrent_to_output_weights =
+ const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
TfLiteTensor* cell_to_input_weights =
@@ -353,10 +353,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
- TfLiteTensor* forget_gate_bias =
+ const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
- TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
- TfLiteTensor* output_gate_bias =
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
TfLiteTensor* projection_weights =
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index ac00c37b67..22c80df19c 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -38,17 +39,26 @@ constexpr int kBiasTensor = 3;
constexpr int kHiddenStateTensor = 0;
constexpr int kOutputTensor = 1;
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* input_weights =
- &context->tensors[node->inputs->data[kWeightsTensor]];
- TfLiteTensor* recurrent_weights =
- &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
- TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
+ const TfLiteTensor* recurrent_weights =
+ GetInput(context, node, kRecurrentWeightsTensor);
+ const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
// Check that all the tensor parameters are consistent with each other and
// with the input configuration.
@@ -63,10 +73,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]);
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]);
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
- TfLiteTensor* hidden_state =
- &context->tensors[node->outputs->data[kHiddenStateTensor]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+ TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Resize state.
TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
@@ -86,22 +97,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size_array));
+ // Allocate temporary tensors to store quantized values of input and
+ // hidden_state tensors.
+ if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[1] = *scratch_tensor_index + 1;
+ TfLiteTensor* hidden_state_quantized =
+ GetTemporary(context, node, /*index=*/1);
+ hidden_state_quantized->type = kTfLiteUInt8;
+ hidden_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
+ hidden_state->dims)) {
+ TfLiteIntArray* hidden_state_quantized_size =
+ TfLiteIntArrayCopy(hidden_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, hidden_state_quantized,
+ hidden_state_quantized_size));
+ }
+ }
return kTfLiteOk;
}
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
-
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* input_weights =
- &context->tensors[node->inputs->data[kWeightsTensor]];
- TfLiteTensor* recurrent_weights =
- &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
- TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
- TfLiteTensor* hidden_state =
- &context->tensors[node->outputs->data[kHiddenStateTensor]];
- TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
-
+TfLiteStatus EvalFloat(const TfLiteTensor* input,
+ const TfLiteTensor* input_weights,
+ const TfLiteTensor* recurrent_weights,
+ const TfLiteTensor* bias,
+ const TfLiteSequenceRNNParams* params,
+ TfLiteTensor* hidden_state, TfLiteTensor* output) {
// Initialize the pointer bias.
const float* bias_ptr = bias->data.f;
@@ -120,7 +153,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (time_major) {
// Initialize the pointer to hidden state.
float* hidden_state_ptr_batch = hidden_state->data.f;
- // Unroll the sequence and use batch batch operations for efficiency.
+ // Unroll the sequence and use batch operations for efficiency.
for (int s = 0; s < max_time; s++) {
// Initialize the pointer to input and output.
const float* input_ptr_batch =
@@ -154,12 +187,114 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus EvalQuantized(const TfLiteTensor* input,
+ const TfLiteTensor* input_weights,
+ const TfLiteTensor* recurrent_weights,
+ const TfLiteTensor* bias,
+ const TfLiteSequenceRNNParams* params,
+ TfLiteTensor* input_scratch,
+ TfLiteTensor* hidden_state_scratch,
+ TfLiteTensor* hidden_state, TfLiteTensor* output) {
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
+ const int num_units = input_weights->dims->data[0];
+ const int input_size = input->dims->data[2];
+
+  // Initialize the bias pointer.
+ const float* bias_ptr = bias->data.f;
+ // Initialize input_weights and recurrent_weights.
+ const int8_t* input_weights_ptr =
+ reinterpret_cast<const int8_t*>(input_weights->data.uint8);
+ const int8_t* recurrent_weights_ptr =
+ reinterpret_cast<const int8_t*>(recurrent_weights->data.uint8);
+ // Get the scale of the quantized weights.
+ float input_weights_scale = input_weights->params.scale;
+ float recurrent_weights_scale = recurrent_weights->params.scale;
+ // Initialize temporary storage for quantized values.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_scratch->data.uint8);
+ int8_t* quantized_hidden_state_ptr =
+ reinterpret_cast<int8_t*>(hidden_state_scratch->data.uint8);
+
+ if (time_major) {
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f;
+ // Unroll the sequence and use batch operations for efficiency.
+ for (int s = 0; s < max_time; s++) {
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch =
+ input->data.f + s * input_size * batch_size;
+ float* output_ptr_batch = output->data.f + s * num_units * batch_size;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, input_weights_ptr, input_weights_scale,
+ recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
+ num_units, batch_size, params->activation, quantized_input_ptr,
+ quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch);
+ }
+ } else {
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
+ for (int s = 0; s < max_time; s++) {
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ float* output_ptr_batch =
+ output->data.f + b * num_units * max_time + s * num_units;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, input_weights_ptr, input_weights_scale,
+ recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
+ input_size, num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, quantized_hidden_state_ptr,
+ hidden_state_ptr_batch, output_ptr_batch);
+ }
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
+ const TfLiteTensor* recurrent_weights =
+ GetInput(context, node, kRecurrentWeightsTensor);
+ const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input_weights->type) {
+ case kTfLiteFloat32:
+ return EvalFloat(input, input_weights, recurrent_weights, bias, params,
+ hidden_state, output);
+ case kTfLiteUInt8: {
+ // TODO(mirkov): implement eval with quantized inputs as well.
+ TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
+ TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
+ return EvalQuantized(input, input_weights, recurrent_weights, bias,
+ params, input_quantized, hidden_state_quantized,
+ hidden_state, output);
+ }
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
} // namespace unidirectional_sequence_rnn
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- unidirectional_sequence_rnn::Prepare,
- unidirectional_sequence_rnn::Eval};
+ static TfLiteRegistration r = {
+ unidirectional_sequence_rnn::Init, unidirectional_sequence_rnn::Free,
+ unidirectional_sequence_rnn::Prepare, unidirectional_sequence_rnn::Eval};
return &r;
}
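
The hybrid path registered above quantizes the float input and hidden state into the two int8 scratch tensors on the fly, runs the matrix products in 8 bit, and rescales the int32 accumulators with the per-tensor weight scales. Below is a minimal, self-contained sketch of one time step for a single batch entry; it is not the actual kernel_utils::RnnBatchStep, and it assumes row-major weights and a fixed ReLU activation in place of params->activation (SymmetricQuantize and HybridRnnStep are illustrative names):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Symmetric per-tensor quantization: map the largest magnitude to +/-127.
static float SymmetricQuantize(const std::vector<float>& values,
                               std::vector<int8_t>* quantized) {
  float max_abs = 0.f;
  for (float v : values) max_abs = std::max(max_abs, std::fabs(v));
  const float scale = max_abs / 127.f;
  quantized->resize(values.size());
  for (size_t i = 0; i < values.size(); ++i) {
    (*quantized)[i] =
        scale > 0.f ? static_cast<int8_t>(std::lround(values[i] / scale)) : 0;
  }
  return scale;
}

// One step: state = output = ReLU(W * input + U * state + bias), with W and U
// held as int8 plus a float scale, as in the hybrid kernel.
static void HybridRnnStep(const std::vector<float>& input,           // [input_size]
                          const std::vector<int8_t>& input_weights,  // [num_units * input_size]
                          float input_weights_scale,
                          const std::vector<int8_t>& recurrent_weights,  // [num_units * num_units]
                          float recurrent_weights_scale,
                          const std::vector<float>& bias,   // [num_units]
                          std::vector<float>* state,        // [num_units], updated in place
                          std::vector<float>* output) {     // [num_units]
  const int input_size = static_cast<int>(input.size());
  const int num_units = static_cast<int>(bias.size());
  std::vector<int8_t> q_input, q_state;
  const float input_scale = SymmetricQuantize(input, &q_input);
  const float state_scale = SymmetricQuantize(*state, &q_state);
  output->assign(num_units, 0.f);
  for (int u = 0; u < num_units; ++u) {
    int32_t acc_input = 0;
    int32_t acc_state = 0;
    for (int i = 0; i < input_size; ++i) {
      acc_input += static_cast<int32_t>(input_weights[u * input_size + i]) * q_input[i];
    }
    for (int i = 0; i < num_units; ++i) {
      acc_state += static_cast<int32_t>(recurrent_weights[u * num_units + i]) * q_state[i];
    }
    // Rescale the int32 accumulators back to float, add bias, apply ReLU.
    const float value = bias[u] +
                        acc_input * input_weights_scale * input_scale +
                        acc_state * recurrent_weights_scale * state_scale;
    (*output)[u] = std::max(0.f, value);
    (*state)[u] = (*output)[u];
  }
}
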
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
index 7e32969763..0adab837b0 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
@@ -122,17 +122,66 @@ static float rnn_golden_output[] = {
0, 2.02616, 0, 0.728256, 0.84183, 0.0907453,
0.628881, 3.58099, 1.49974, 0};
+static std::initializer_list<float> rnn_weights = {
+ 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
+ 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
+ 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
+ -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
+ -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
+ -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
+ -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
+ 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
+ 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
+ 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
+ -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
+ 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
+ -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
+ -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
+ 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
+ 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
+ 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
+ -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
+ 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
+ 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
+ -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
+ 0.277308, 0.415818};
+
+static std::initializer_list<float> rnn_recurrent_weights = {
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1};
+
+static std::initializer_list<float> rnn_bias = {
+ 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568,
+ -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178,
+ 0.37197268, 0.61957061, 0.3956964, -0.37609905};
+
class UnidirectionalRNNOpModel : public SingleOpModel {
public:
- UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size,
- bool time_major)
+ UnidirectionalRNNOpModel(
+ int batches, int sequence_len, int units, int size, bool time_major,
+ const TensorType& weights = TensorType_FLOAT32,
+ const TensorType& recurrent_weights = TensorType_FLOAT32)
: batches_(batches),
sequence_len_(sequence_len),
units_(units),
input_size_(size) {
input_ = AddInput(TensorType_FLOAT32);
- weights_ = AddInput(TensorType_FLOAT32);
- recurrent_weights_ = AddInput(TensorType_FLOAT32);
+ weights_ = AddInput(weights);
+ recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
hidden_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -187,7 +236,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
int num_batches() { return batches_; }
int sequence_len() { return sequence_len_; }
- private:
+ protected:
int input_;
int weights_;
int recurrent_weights_;
@@ -201,58 +250,31 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
int input_size_;
};
-// TODO(mirkov): add another test which directly compares to TF once TOCO
-// supports the conversion from dynamic_rnn with BasicRNNCell.
-TEST(FullyConnectedOpTest, BlackBoxTest) {
+// The hybrid model has quantized weights and recurrent_weights.
+class HybridUnidirectionalRNNOpModel : public UnidirectionalRNNOpModel {
+ public:
+ HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units,
+ int size, bool time_major)
+ : UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major,
+ TensorType_UINT8, TensorType_UINT8) {}
+
+ void SetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(weights_, f);
+ }
+
+ void SetRecurrentWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_weights_, f);
+ }
+};
+
+TEST(UnidirectionalRNNOpTest, BlackBoxTest) {
UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*units=*/16, /*size=*/8, /*time_major=*/false);
- rnn.SetWeights(
- {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
- 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
- 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
- -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
- -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
- -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
- -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
- 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
- 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
- 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
- -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
- 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
- -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
- -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
- 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
- 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
- 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
- -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
- 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
- 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
- -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
- 0.277308, 0.415818});
-
- rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
- -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
- 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
- -0.37609905});
-
- rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1});
-
+ rnn.SetWeights(rnn_weights);
+ rnn.SetBias(rnn_bias);
+ rnn.SetRecurrentWeights(rnn_recurrent_weights);
rnn.ResetHiddenState();
+
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
float* batch_end = batch_start + input_sequence_size;
@@ -270,56 +292,42 @@ TEST(FullyConnectedOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}
-TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) {
- UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
- /*units=*/16, /*size=*/8, /*time_major=*/true);
- rnn.SetWeights(
- {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
- 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
- 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
- -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
- -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
- -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
- -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
- 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
- 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
- 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
- -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
- 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
- -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
- -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
- 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
- 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
- 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
- -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
- 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
- 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
- -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
- 0.277308, 0.415818});
-
- rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
- -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
- 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
- -0.37609905});
-
- rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1});
+TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTest) {
+ HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*units=*/16, /*size=*/8,
+ /*time_major=*/false);
+ rnn.SetWeights(rnn_weights);
+ rnn.SetBias(rnn_bias);
+ rnn.SetRecurrentWeights(rnn_recurrent_weights);
+ rnn.ResetHiddenState();
+
+ const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
+ float* batch_start = rnn_input;
+ float* batch_end = batch_start + input_sequence_size;
+ rnn.SetInput(0, batch_start, batch_end);
+ rnn.SetInput(input_sequence_size, batch_start, batch_end);
+
+ rnn.Invoke();
+
+ float* golden_start = rnn_golden_output;
+ float* golden_end = golden_start + rnn.num_units() * rnn.sequence_len();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ expected, /*max_abs_error=*/0.013)));
+}
+
+TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) {
+ UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*units=*/16, /*size=*/8,
+ /*time_major=*/true);
+ rnn.SetWeights(rnn_weights);
+ rnn.SetBias(rnn_bias);
+ rnn.SetRecurrentWeights(rnn_recurrent_weights);
rnn.ResetHiddenState();
+
for (int i = 0; i < rnn.sequence_len(); i++) {
float* batch_start = rnn_input + i * rnn.input_size();
float* batch_end = batch_start + rnn.input_size();
@@ -341,6 +349,37 @@ TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) {
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}
+TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTest) {
+ HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*units=*/16, /*size=*/8,
+ /*time_major=*/true);
+ rnn.SetWeights(rnn_weights);
+ rnn.SetBias(rnn_bias);
+ rnn.SetRecurrentWeights(rnn_recurrent_weights);
+ rnn.ResetHiddenState();
+
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* batch_start = rnn_input + i * rnn.input_size();
+ float* batch_end = batch_start + rnn.input_size();
+ // The two batches are identical.
+ rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
+ rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
+ }
+
+ rnn.Invoke();
+
+ std::vector<float> expected;
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* golden_batch_start = rnn_golden_output + i * rnn.num_units();
+ float* golden_batch_end = golden_batch_start + rnn.num_units();
+ expected.insert(expected.end(), golden_batch_start, golden_batch_end);
+ expected.insert(expected.end(), golden_batch_start, golden_batch_end);
+ }
+
+ EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ expected, /*max_abs_error=*/0.013)));
+}
+
} // namespace
} // namespace tflite
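
The hybrid tests above pass a looser max_abs_error of 0.013 so that the rounding noise from 8-bit symmetric quantization of the weights (and of the on-the-fly quantized activations) does not fail the comparison against the float golden output. A rough back-of-the-envelope check, assuming the largest-magnitude weight maps to +/-127:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // A few of the largest-magnitude entries from rnn_weights above.
  const std::vector<float> sample = {0.969689f, -0.989183f, -0.992843f, 0.972934f};
  float max_abs = 0.f;
  for (float w : sample) max_abs = std::max(max_abs, std::fabs(w));
  const float scale = max_abs / 127.f;         // ~0.0078 for the full tensor
  const float per_weight_error = scale / 2.f;  // worst-case rounding, ~0.0039
  std::printf("scale=%.4f per-weight rounding error=%.4f\n", scale,
              per_weight_error);
  return 0;
}
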
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index e89036ce73..abbdec23bb 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -184,8 +184,10 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
TfLiteStatus status = kTfLiteOk;
auto opcodes = model_->operator_codes();
for (const OperatorCode* opcode : *opcodes) {
- TfLiteRegistration* registration = nullptr;
+ const TfLiteRegistration* registration = nullptr;
auto builtin_code = opcode->builtin_code();
+ int version = opcode->version();
+
if (builtin_code > BuiltinOperator_MAX ||
builtin_code < BuiltinOperator_MIN) {
error_reporter_->Report(
@@ -194,8 +196,7 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
builtin_code);
status = kTfLiteError;
} else if (builtin_code != BuiltinOperator_CUSTOM) {
- flatbuffer_op_index_to_registration_types_.push_back(builtin_code);
- registration = op_resolver_.FindOp(builtin_code);
+ registration = op_resolver_.FindOp(builtin_code, version);
if (registration == nullptr) {
error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
EnumNameBuiltinOperator(builtin_code));
@@ -207,11 +208,13 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
status = kTfLiteError;
} else {
const char* name = opcode->custom_code()->c_str();
- registration = op_resolver_.FindOp(name);
+ registration = op_resolver_.FindOp(name, version);
flatbuffer_op_index_to_registration_types_.push_back(
BuiltinOperator_CUSTOM);
if (registration == nullptr) {
- error_reporter_->Report("Didn't find custom op for name '%s'\n", name);
+ error_reporter_->Report(
+ "Didn't find custom op for name '%s' with version %d\n", name,
+ version);
status = kTfLiteError;
}
}
@@ -333,6 +336,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->stride_height = conv_params->stride_h();
params->activation =
parse_activation(conv_params->fused_activation_function());
+
params->dilation_width_factor = conv_params->dilation_w_factor();
params->dilation_height_factor = conv_params->dilation_h_factor();
}
@@ -352,6 +356,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_PRELU:
case BuiltinOperator_FLOOR:
case BuiltinOperator_NEG:
+ case BuiltinOperator_SIN:
break;
case BuiltinOperator_CAST: {
TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
@@ -679,6 +684,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_SELECT: {
break;
}
+ case BuiltinOperator_SLICE: {
+ break;
+ }
case BuiltinOperator_DELEGATE: {
// TODO(ycling): Revisit when supporting saving delegated models.
error_reporter->Report("DELEGATE op shouldn't exist in model.");
@@ -703,27 +711,30 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
status = kTfLiteError;
continue;
}
- const TfLiteRegistration* reg =
+
+ const TfLiteRegistration* registration =
flatbuffer_op_index_to_registration_[op->opcode_index()];
- if (reg == nullptr) {
+ if (registration == nullptr) {
error_reporter_->Report("Skipping op for opcode_index %d\n", index);
status = kTfLiteError;
continue;
}
- auto op_type =
- flatbuffer_op_index_to_registration_types_[op->opcode_index()];
+ BuiltinOperator op_type =
+ static_cast<BuiltinOperator>(registration->builtin_code);
+
if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
error_reporter_->Report(
"Found builtin operator %s with custom options.\n",
EnumNameBuiltinOperator(op_type));
}
+
if (op->custom_options()) {
interpreter->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()),
reinterpret_cast<const char*>(op->custom_options()->data()),
- op->custom_options()->size(), nullptr, reg);
+ op->custom_options()->size(), nullptr, registration);
} else {
void* builtin_data = nullptr;
TF_LITE_ENSURE_STATUS(
@@ -731,7 +742,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
interpreter->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, builtin_data,
- reg);
+ registration);
}
}
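
With this change the interpreter builder resolves every opcode through the version-aware FindOp overloads rather than the old single-argument ones. The sketch below shows one hypothetical way a resolver could key registrations on (op, version) pairs; it only illustrates the lookup the builder depends on and is not the real tflite::MutableOpResolver:

#include <map>
#include <string>
#include <utility>

struct Registration {  // stand-in for TfLiteRegistration in this sketch
  int builtin_code = 0;
  int version = 1;
};

class VersionedResolver {
 public:
  void AddBuiltin(int builtin_code, const Registration& reg, int version = 1) {
    builtins_[{builtin_code, version}] = reg;
  }
  void AddCustom(const std::string& name, const Registration& reg,
                 int version = 1) {
    customs_[{name, version}] = reg;
  }
  // Mirrors the shape of FindOp(builtin_code, version): a null result means the
  // (code, version) pair was never registered, which the builder reports.
  const Registration* FindOp(int builtin_code, int version) const {
    auto it = builtins_.find({builtin_code, version});
    return it == builtins_.end() ? nullptr : &it->second;
  }
  const Registration* FindOp(const std::string& name, int version) const {
    auto it = customs_.find({name, version});
    return it == customs_.end() ? nullptr : &it->second;
  }

 private:
  std::map<std::pair<int, int>, Registration> builtins_;
  std::map<std::pair<std::string, int>, Registration> customs_;
};

Keying on the pair lets older and newer kernel implementations coexist in one resolver, which is what allows model.cc to report a missing (name, version) combination explicitly.
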
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 5a55b031a8..3946b49041 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -37,6 +37,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
namespace tflite {
@@ -131,18 +132,6 @@ class FlatBufferModel {
Allocation* allocation_ = nullptr;
};
-// Abstract interface that returns TfLiteRegistrations given op codes or custom
-// op names. This is the mechanism that ops being referenced in the flatbuffer
-// model are mapped to executable function pointers (TfLiteRegistrations).
-class OpResolver {
- public:
- // Finds the op registration for a builtin operator by enum code.
- virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0;
- // Finds the op registration of a custom operator by op name.
- virtual TfLiteRegistration* FindOp(const char* op) const = 0;
- virtual ~OpResolver() {}
-};
-
// Build an interpreter capable of interpreting `model`.
//
// model: a scoped model whose lifetime must be at least as long as
@@ -187,7 +176,7 @@ class InterpreterBuilder {
const OpResolver& op_resolver_;
ErrorReporter* error_reporter_;
- std::vector<TfLiteRegistration*> flatbuffer_op_index_to_registration_;
+ std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
const Allocation* allocation_ = nullptr;
};
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index ae6c1ece18..15bae21a41 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -55,11 +55,12 @@ class TrivialResolver : public OpResolver {
explicit TrivialResolver(TfLiteRegistration* constant_return = nullptr)
: constant_return_(constant_return) {}
// Find the op registration of a custom operator by op name.
- TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override {
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override {
return constant_return_;
}
// Find the op registration of a custom operator by op name.
- TfLiteRegistration* FindOp(const char* op) const override {
+ const TfLiteRegistration* FindOp(const char* op, int version) const override {
return constant_return_;
}
diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD
index a82d1f2eb6..8b5fa240ac 100644
--- a/tensorflow/contrib/lite/models/smartreply/BUILD
+++ b/tensorflow/contrib/lite/models/smartreply/BUILD
@@ -22,7 +22,6 @@ cc_library(
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/kernels:builtin_ops",
- "//tensorflow/contrib/lite/tools:mutable_op_resolver",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
"@farmhash_archive//:farmhash",
@@ -39,7 +38,6 @@ cc_library(
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/kernels:builtin_ops",
- "//tensorflow/contrib/lite/tools:mutable_op_resolver",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
],
diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
index f97a6486d6..29c8ad2286 100644
--- a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
+++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc
@@ -61,7 +61,7 @@ bool IsValidNgram(const tflite::StringRef& strref) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArray* outputSize1 = TfLiteIntArrayCreate(1);
TfLiteIntArray* outputSize2 = TfLiteIntArrayCreate(1);
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
int dim = input->dims->data[0];
if (dim == 0) {
// TFLite non-string output should have size greater than 0.
@@ -76,7 +76,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* input = GetInput(context, node, 0);
+ const TfLiteTensor* input = GetInput(context, node, 0);
int num_strings = tflite::GetStringCount(input);
TfLiteTensor* label = GetOutput(context, node, 0);
TfLiteTensor* weight = GetOutput(context, node, 1);
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/contrib/lite/models/smartreply/predictor.cc
index 6da5cc8eec..ceef8e6a29 100644
--- a/tensorflow/contrib/lite/models/smartreply/predictor.cc
+++ b/tensorflow/contrib/lite/models/smartreply/predictor.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
#include "tensorflow/contrib/lite/string_util.h"
-#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
index 4a648e4283..becd1f615f 100644
--- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
+++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
@@ -65,7 +65,8 @@ inline bool NNAPIExists() {
return nnapi_is_available;
}
-// nn api types
+// NN API types, based on the NNAPI header file:
+// https://developer.android.com/ndk/reference/group/neural-networks
/**
* Operand types.
@@ -77,31 +78,11 @@ inline bool NNAPIExists() {
* ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, and ANEURALNETWORKS_INT32.
*/
enum {
- /** The following entries are used to declare scalars. */
-
- /** A 32 bit floating point scalar value. */
ANEURALNETWORKS_FLOAT32 = 0,
- /** A signed 32 bit integer scalar value. */
ANEURALNETWORKS_INT32 = 1,
- /** An unsigned 32 bit integer scalar value. */
ANEURALNETWORKS_UINT32 = 2,
-
- /** The following entries are used to declare tensors. */
-
- /** A tensor of 32 bit floating point values. */
ANEURALNETWORKS_TENSOR_FLOAT32 = 3,
- /** A tensor of 32 bit integer values. */
ANEURALNETWORKS_TENSOR_INT32 = 4,
- /** A tensor of 8 bit integers that represent real numbers.
- *
- * Attached to this tensor are two numbers that can be used to convert
- * the 8 bit integer to the real value and vice versa. These two numbers are:
- * - scale: a 32 bit floating point value
- * - zero_value: an 32 bit integer
- *
- * The formula is:
- * real_value = (integer_value - zero_value) * scale.
- */
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5,
};
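
The quant8 documentation removed above reduces to the affine mapping real_value = (integer_value - zero_value) * scale. A small illustrative pair of helpers (not part of the NNAPI surface) showing that mapping and a clamped inverse, assuming scale > 0:

#include <algorithm>
#include <cmath>
#include <cstdint>

// real_value = (integer_value - zero_value) * scale
inline float Dequantize(uint8_t q, int32_t zero_value, float scale) {
  return (static_cast<int32_t>(q) - zero_value) * scale;
}

// Inverse mapping: round to nearest and clamp to the uint8 range.
inline uint8_t Quantize(float real, int32_t zero_value, float scale) {
  long q = std::lround(real / scale) + zero_value;
  q = std::min(255L, std::max(0L, q));
  return static_cast<uint8_t>(q);
}
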
@@ -111,968 +92,44 @@ enum {
* The type of operations that can be added to a model.
*/
enum {
- /** Adds two tensors, element-wise.
- *
- * Takes two input tensors of identical type and compatible dimensions. The
- * output is the sum of both input tensors, optionally modified by an
- * activation function.
- *
- * Two dimensions are compatible when:
- * 1. they are equal, or
- * 2. one of them is 1
- *
- * The size of the output is the maximum size along each dimension of the
- * input operands. It starts with the trailing dimensions, and works its way
- * forward.
- *
- * Example:
- *
- * input1.dimension = {4, 1, 2}
- * input2.dimension = {5, 4, 3, 1}
- * output.dimension = {5, 4, 3, 2}
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- *
- * Supported tensor rank: up to 4
- *
- * Inputs:
- * * 0: A tensor.
- * * 1: A tensor of the same type, and compatible dimensions as input0.
- * * 2: An INT32 value, and has to be one of the {@link FuseCode} values.
- * Specifies the activation to invoke on the result of each addition.
- *
- * Outputs:
- * * 0: The sum, a tensor of the same type as input0.
- */
ANEURALNETWORKS_ADD = 0,
- /** Performs a 2-D average pooling operation.
- *
- * The output dimensions are functions of the filter dimensions, stride, and
- * padding.
- *
- * The values in the output tensor are computed as:
- *
- * output[batch, row, col, channel] =
- * sum_{i, j}(input[batch, row + i, col + j, channel]) / sum(1)
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: 4, with "NHWC" data layout.
- *
- * Inputs:
- * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
- * input.
- * * 1: An INT32 value, specifying the padding on the left, in the ‘width’
- * dimension.
- * * 2: An INT32 value, specifying the padding on the right,in the ‘width’
- * dimension.
- * * 3: An INT32 value, specifying the padding on the top, in the ‘height’
- * dimension.
- * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’
- * dimension.
- * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension.
- * * 6: An INT32 value, specifying the output stride in the ‘height’
- * dimension.
- * * 7: An INT32 value, specifying the filter width.
- * * 8: An INT32 value, specifying the filter height.
- * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
- * Specifies the activation to invoke on the result of each addition.
- *
- * Outputs:
- * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
- * depth].
- */
ANEURALNETWORKS_AVERAGE_POOL_2D = 1,
- /** Concatenates the input tensors along the given dimension.
- *
- * The input tensors must have identical type and the same dimensions except
- * the dimension along the concatenation axis.
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: up to 4
- *
- * Inputs:
- * 0 ~ n: The list on n input tensors, of shape [D0, D1, ..., Daxis(i), ...,
- * Dm] n+1: An INT32 value, specifying the concatenation axis. n+2: An INT32
- * value, and has to be one of the {@link FuseCode} values. Specifies the
- * activation to invoke on the result of each addition.
- *
- * Outputs:
- * * 0: The output, a tensor of the same type as the input tensors.
- * The output shape is [D0, D1, ..., sum(Daxis(i)), ..., Dm].
- */
ANEURALNETWORKS_CONCATENATION = 2,
- /** Performs an 2-D convolution operation.
- *
- * The CONV_2D op sweeps a 2-D filter that can mix channels together over a
- * batch of images, applying the filter to each window of each image of the
- * appropriate size.
- *
- * The output dimensions are functions of the filter dimensions, stride, and
- * padding.
- *
- * The values in the output tensor are computed as:
- *
- * output[batch, row, col, channel] =
- * sum_{i, j} (
- * input[batch, row + i, col + j, k] *
- * filter[channel, row + i, col + j, k] +
- * bias[channel]
- * )
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: 4, with "NHWC" data layout.
- *
- * Inputs:
- * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
- * the input.
- * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width,
- * depth_in], specifying the filter.
- * * 2: A 1-D tensor, of shape [depth_out], specifying the bias.
- * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the
- * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input
- * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should
- * be of {@link ANEURALNETWORKS_TENSOR_INT32}.
- * * 3: An INT32 value, specifying the padding on the left, in the ‘width’
- * dimension.
- * * 4: An INT32 value, specifying the padding on the right,in the ‘width’
- * dimension.
- * * 5: An INT32 value, specifying the padding on the top, in the ‘height’
- * dimension.
- * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’
- * dimension.
- * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension.
- * * 8: An INT32 value, specifying the output stride in the ‘height’
- * dimension.
- * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
- * Specifies the activation to invoke on the result of each addition.
- *
- * Outputs:
- * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
- * depth_out].
- */
ANEURALNETWORKS_CONV_2D = 3,
- /** Performs a depthwise 2-D convolution operation.
- *
- * Given an input tensor of shape [batches, height, width, depth_in] and a
- * filter tensor of shape [depth_out, filter_height, filter_width, depth_in]
- * containing in_channels convolutional filters of depth 1, DEPTHWISE_CONV
- * applies a different filter to each input channel (expanding from 1 channel
- * to channel_multiplier channels for each), then concatenates the results
- * together.
- *
- * The output has depth_out = depth_in * depth_multiplier channels.
- * The output dimensions are functions of the filter dimensions, stride, and
- * padding.
- *
- * The values in the output tensor are computed as:
- *
- * output[b, i, j, k * channel_multiplier + q] =
- * sum_{di, dj} (
- * input[b, strides[1] * i + di, strides[2] * j + dj, k] *
- * filter[di, dj, k, q]
- * )
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: 4, with "NHWC" data layout.
- *
- * Inputs:
- * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
- * the input.
- * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width,
- * depth_in], specifying the filter.
- * * 2: A 1-D tensor, of shape [depth_out], specifying the bias.
- * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the
- * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input
- * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should
- * be of {@link ANEURALNETWORKS_TENSOR_INT32}.
- * * 3: An INT32 value, specifying the padding on the left, in the ‘width’
- * dimension.
- * * 4: An INT32 value, specifying the padding on the right,in the ‘width’
- * dimension.
- * * 5: An INT32 value, specifying the padding on the top, in the ‘height’
- * dimension.
- * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’
- * dimension.
- * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension.
- * * 8: An INT32 value, specifying the output stride in the ‘height’
- * dimension.
- * * 9: An INT32 value, specifying the depthwise multiplier.
- * * 10: An INT32 value, and has to be one of the {@link FuseCode} values.
- * Specifies the activation to invoke on the result of each addition.
- *
- * Outputs:
- * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
- * depth_out].
- */
ANEURALNETWORKS_DEPTHWISE_CONV_2D = 4,
- /** Rearranges data from depth into blocks of spatial data.
- *
- * More specifically, this op outputs a copy of the input tensor where values
- * from the depth dimension are moved in spatial blocks to the height and
- * width dimensions. The value block_size indicates the input block size and
- * how the data is moved.
- *
- * Chunks of data of size block_size * block_size from depth are rearranged
- * into non-overlapping blocks of size block_size x block_size.
- *
- * The width of the output tensor is input_depth * block_size, whereas the
- * height is input_height * block_size. The depth of the input tensor must be
- * divisible by block_size * block_size
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: 4, with "NHWC" data layout.
- *
- * Inputs:
- * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
- * the input.
- * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and
- * block_size * block_size must be a divisor of the input depth.
- *
- * Outputs:
- * * 0: The output 4-D tensor, of shape [batch, height*block_size,
- * width*block_size, depth/(block_size*block_size)].
- */
ANEURALNETWORKS_DEPTH_TO_SPACE = 5,
- /** Dequantizes the input tensor.
- *
- * The formula is:
- *
- * output = (input - zero_value) * scale.
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: up to 4
- *
- * Inputs:
- * * 0: A tensor of type {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}.
- *
- * Outputs:
- * * 0: The output tensor of same shape as input0, but with type
- * {@link ANEURALNETWORKS_TENSOR_FLOAT32}.
- */
ANEURALNETWORKS_DEQUANTIZE = 6,
-
- /**
- * Looks up items from a given tensor.
- *
- * Each item in the output is a raw copy of the corresponding item in
- * the input “values”. If the given “lookup” indices are out of bounds,
- * the op will fail and an error will be reported.
- *
- * Inputs:
- * * 0: Values. An n-D tensor of any type X (where n >= 2). E.g., if n is 2,
- * then the shape would be [lookup_dimension, values_dimension], where
- * “lookup_dimension” corresponds to the indexing dimension in the lookup
- * table, and “values_dimension” to the contents.
- * * 1: Lookups. An 1-D tensor of type T, of shape [lookup_size], where
- * “lookup_size” is the number of elements to look for, and each entry
- * corresponds to the first dimension of the “values” tensor.
- *
- * Output:
- * * 0: A n-D tensor of type X and the same rank and shape as the “values”
- * tensor, except for the first dimension which has size “lookup_size”.
- */
ANEURALNETWORKS_EMBEDDING_LOOKUP = 7,
-
- /** Computes element-wise floor() on the input tensor.
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- *
- * Supported tensor rank: up to 4
- *
- * Inputs:
- * * 0: A tensor.
- *
- * Outputs:
- * * 0: The output, a tensor of the same type and dimensions as input0.
- */
ANEURALNETWORKS_FLOOR = 8,
- /** Denotes a fully (densely) connected layer, which connects all elements in
- * the input tensor with each element in the output tensor.
- *
- * This layer implements the operation:
- *
- * outputs = activation(inputs * weights’ + bias)
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: up to 4.
- *
- * Inputs:
- * * 0: A tensor, specifying the input. If rank is greater than 2, then it
- * gets flattened to a 2-D Tensor. The 2-D Tensor is handled as if dimensions
- * corresponded to shape [batch_size, input_size], where “batch_size”
- * corresponds to the batching dimension, and “input_size” is the size of the
- * input.
- * * 1: A 2-D tensor, specifying the weights, of shape [num_units,
- * input_size], where "num_units" corresponds to the number of output nodes.
- * * 2: A 1-D tensor, of shape [num_units], specifying the bias.
- * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the
- * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input
- * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should
- * be of {@link ANEURALNETWORKS_TENSOR_INT32}.
- * * 3: An INT32 value, and has to be one of the {@link FuseCode} values.
- * Specifies the activation to invoke on the result of each addition.
- *
- * Outputs:
- * * 0: The output tensor, of shape [batch_size, num_units].
- */
ANEURALNETWORKS_FULLY_CONNECTED = 9,
-
- /**
- * Looks up values of a hash table with given keys.
- *
- * Inputs:
- * * 0: Lookups. A 1-D int32 tensor with shape [ k ].
- * * 1: Keys. A 1-D int32 tensor with shape [ n ], *MUST* be sorted in
- * ascending order.
- * * 2: Values. A tensor with shape [ n … ].
- *
- * Outputs:
- * * 0: Output. A tensor with shape [ k …].
- * * 1: Hits. A uint8 tensor with shape [ k ] indicates whether the lookup
- * hits or not.
- */
ANEURALNETWORKS_HASHTABLE_LOOKUP = 10,
-
- /** Applies L2 normalization along the depth dimension.
- *
- * The values in the output tensor are computed as:
- *
- * output[batch, row, col, channel] =
- * input[batch, row, col, channel] /
- * sqrt(sum_{c} pow(input[batch, row, col, c], 2))
- *
- * For x with more dimensions, independently normalizes each 1-D slice along
- * dimension dim.
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- *
- * Supported tensor rank: 4, with "NHWC" data layout.
- *
- * Inputs:
- * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
- * input.
- *
- * Outputs:
- * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
- * depth].
- */
ANEURALNETWORKS_L2_NORMALIZATION = 11,
-
- /** Performs an 2-D L2 pooling operation.
- *
- * The output dimensions are functions of the filter dimensions, stride, and
- * padding.
- *
- * The values in the output tensor are computed as:
- *
- * output[batch, row, col, channel] =
- * sqrt(sum_{i, j} pow(input[batch, row + i, col + j, channel], 2) /
- * sum(1))
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- *
- * Supported tensor rank: 4, with "NHWC" data layout.
- *
- * Inputs:
- * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
- * input.
- * * 1: An INT32 value, specifying the padding on the left, in the ‘width’
- * dimension.
- * * 2: An INT32 value, specifying the padding on the right,in the ‘width’
- * dimension.
- * * 3: An INT32 value, specifying the padding on the top, in the ‘height’
- * dimension.
- * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’
- * dimension.
- * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension.
- * * 6: An INT32 value, specifying the output stride in the ‘height’
- * dimension.
- * * 7: An INT32 value, specifying the filter width.
- * * 8: An INT32 value, specifying the filter height.
- * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
- * Specifies the activation to invoke on the result of each addition.
- *
- * Outputs:
- * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
- * depth].
- */
ANEURALNETWORKS_L2_POOL_2D = 12,
- /** Applies Local Response Normalization along the depth dimension.
- *
- * The 4-D input tensor is treated as a 3-D array of 1-D vectors (along the
- * last dimension), and each vector is normalized independently. Within a
- * given vector, each component is divided by the weighted, squared sum of
- * inputs within depth_radius.
- *
- * The output is calculated using this formula:
- *
- * sqr_sum[a, b, c, d] =
- * sum(pow(input[a, b, c, d - depth_radius : d + depth_radius + 1], 2)
- * output = input / pow((bias + alpha * sqr_sum), beta)
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- *
- * Supported tensor rank: 4, with "NHWC" data layout.
- *
- * Inputs:
- * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
- * input.
- * * 1: An INT32 value, specifying the radius of the normalization window.
- * * 2: A FLOAT32 value, specifying the bias, must not be zero.
- * * 3: A FLOAT32 value, specifying the scale factor, alpha.
- * * 4: A FLOAT32 value, specifying the exponent, beta.
- *
- * Outputs:
- * * 0: The output tensor of same shape as input0.
- */
ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION = 13,
- /** Computes sigmoid activation on the input tensor element-wise.
- *
- * The output is calculated using this formula:
- *
- * output = 1 / (1 + exp(-input))
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: up to 4.
- *
- * Inputs:
- * * 0: A tensor, specifying the input.
- *
- * Outputs:
- * * 0: The output tensor of same shape as input0.
- */
ANEURALNETWORKS_LOGISTIC = 14,
-
- /**
- * Projects an input to a bit vector via locality sensitive hashing.
- *
- * Inputs:
- * * 0: Hash functions. Dim.size == 2, DataType: Float.
- * Tensor[0].Dim[0]: Number of hash functions.
- * Tensor[0].Dim[1]: Number of seeds per hash functions.
- * Tensor[0].Dim[1] <= 32 in sparse case.
- *
- * * 1: Input. Dim.size >= 1, no restriction on DataType.
- * * 2: Weight. Optional. Dim.size == 1, DataType: Float.
- * If not set, each input element is considered to have the same weight of
- * 1.0.
- * Tensor[1].Dim[0] == Tensor[2].Dim[0]
- * * 3: Type:
- * Sparse: Value LSHProjectionType_SPARSE(=1).
- * Computed bit vector is considered to be sparse.
- * Each output element is an int32 made up of multiple bits computed
- * from hash functions.
- *
- * Dense: Value LSHProjectionType_DENSE(=2).
- * Computed bit vector is considered to be dense. Each output element
- * represents a bit and can take the value of either 0 or 1.
- *
- * Outputs:
- * * 0: If the projection type is sparse:
- * Output.Dim == { Tensor[0].Dim[0] }
- * A tensor of int32 that represents hash signatures.
- * If the projection type is Dense:
- * Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] }
- * A flattened tensor that represents projected bit vectors.
- */
ANEURALNETWORKS_LSH_PROJECTION = 15,
-
- /**
- * Long short-term memory unit (LSTM) recurrent network layer.
- *
- * The default non-peephole implementation is based on:
- * http://www.bioinf.jku.at/publications/older/2604.pdf
- * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural
- * Computation, 9(8):1735-1780, 1997.
- *
- * The peephole implementation is based on:
- * https://research.google.com/pubs/archive/43905.pdf
- * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory
- * recurrent neural network architectures for large scale acoustic modeling."
- * INTERSPEECH, 2014.
- *
- * The coupling of input and forget gate (CIFG) is based on:
- * http://arxiv.org/pdf/1503.04069.pdf
- * Greff et al. "LSTM: A Search Space Odyssey"
- *
- * The class has the following independently optional inputs:
- * * If input gate (if CIFG): “input_to_forget_weights”,
- * “recurrent_to_input_weights”, “cell_to_input_weights”, “input_gate_bias”.
- * * If no peephole connections: “cell_to_input_weights”,
- * “cell_to_forget_weights”, “cell_to_output_weights”.
- * * If no projection layer: “projection_weights” and “projection_bias”.
- * * If no projection bias: “projection_bias”.
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- *
- * Inputs:
- * * 0: Input.
- * A 2-D tensor of type T, of shape [batch_size, input_size], where
- * “batch_size” corresponds to the batching dimension, and “input_size”
- * is the size of the input.
- * * 1: input_to_input_weights.
- * A 2-D tensor of type T, of shape [num_units, input_size], where
- * “num_units” corresponds to the number of cell units.
- * * 2: input_to_forget_weights.
- * A 2-D tensor of type T, of shape [num_units, input_size].
- * * 3: input_to_cell_weights.
- * A 2-D tensor of type T, of shape [num_units, input_size].
- * * 4: input_to_output_weights.
- * A 2-D tensor of type T, of shape [num_units, input_size].
- * * 5: recurrent_to_input_weights.
- * A 2-D tensor of type T, of shape [num_units, output_size], where
- * “output_size” corresponds to either the number of cell units (i.e.,
- * “num_units”), or the second dimension of the “projection_weights”, if
- * defined.
- * * 6: recurrent_to_forget_weights.
- * A 2-D tensor of type T, of shape [num_units, output_size].
- * * 7: recurrent_to_cell_weights.
- * A 2-D tensor of type T, of shape [num_units, output_size].
- * * 8: recurrent_to_output_weights.
- * A 2-D tensor of type T, of shape [num_units, output_size].
- * * 9: cell_to_input_weights.
- * A 1-D tensor of type T, of shape [num_units].
- * * 10:cell_to_forget_weights.
- * A 1-D tensor of type T, of shape [num_units].
- * * 11:cell_to_output_weights.
- * A 1-D tensor of type T, of shape [num_units].
- * * 12:input_gate_bias.
- * A 1-D tensor of type T, of shape [num_units].
- * * 13:forget_gate_bias.
- * A 1-D tensor of type T, of shape [num_units].
- * * 14:cell_bias.
- * A 1-D tensor of type T, of shape [num_units].
- * * 15:output_gate_bias.
- * A 1-D tensor of type T, of shape [num_units].
- * * 16:projection_weights.
- * A 2-D tensor of type T, of shape [output_size, num_units].
- * * 17:projection_bias.
- * A 1-D tensor of type T, of shape [output_size].
- *
- * Parameters:
- * * 18:fused_activation_function.
- * An (optional) ActivationFunctionType indicating the activation
- * function.
- * If “NONE” is specified then it results in a linear activation.
- * * 19:cell_clip.
- * A clipping threshold for the cell state, such that values are bound
- * within [-cell_clip, cell_clip]. If set to 0.0 then clipping is
- * disabled.
- * * 20:proj_clip.
- * A clipping threshold for the output from the projection layer, such
- * that values are bound within [-proj_clip, proj_clip]. If set to 0.0
- * then clipping is disabled.
- *
- * Outputs:
- * * 0: scratch_buffer.
- * A 3-D tensor of type T, of shape [batch_size, num_cell, 4].
- * * 1: output_state.
- * A 2-D tensor of type T, of shape [batch_size, output_size].
- * * 2: cell_state.
- * A 2-D tensor of type T, of shape [batch_size, num_units].
- * * 3: output.
- * A 2-D tensor of type T, of shape [batch_size, output_size]. This is
- * effectively the same as the current “output_state” value.
- */
ANEURALNETWORKS_LSTM = 16,
-
- /** Performs an 2-D max pooling operation.
- *
- * The output dimensions are functions of the filter dimensions, stride, and
- * padding.
- *
- * The values in the output tensor are computed as:
- *
- * output[batch, row, col, channel] =
- * max_{i, j} (input[batch, row + i, col + j, channel])
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: 4, with "NHWC" data layout.
- *
- * Inputs:
- * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
- * input.
- * * 1: An INT32 value, specifying the padding on the left, in the ‘width’
- * dimension.
- * * 2: An INT32 value, specifying the padding on the right,in the ‘width’
- * dimension.
- * * 3: An INT32 value, specifying the padding on the top, in the ‘height’
- * dimension.
- * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’
- * dimension.
- * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension.
- * * 6: An INT32 value, specifying the output stride in the ‘height’
- * dimension.
- * * 7: An INT32 value, specifying the filter width.
- * * 8: An INT32 value, specifying the filter height.
- * * 9: An INT32 value, and has to be one of the {@link FuseCode} values.
- * Specifies the activation to invoke on the result of each addition.
- *
- * Outputs:
- * * 0: The output 4-D tensor, of shape [batches, out_height, out_width,
- * depth].
- */
ANEURALNETWORKS_MAX_POOL_2D = 17,
-
- /** Multiplies two tensors, element-wise.
- *
- * Takes two input tensors of identical type and compatible dimensions. The
- * output is the product of both input tensors, optionally modified by an
- * activation function.
- *
- * Two dimensions are compatible when:
- * 1. they are equal, or
- * 2. one of them is 1
- *
- * The size of the resulting output is the maximum size along each dimension
- * of the input operands. It starts with the trailing dimensions, and works
- * its way forward.
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- *
- * Supported tensor rank: up to 4
- *
- * Inputs:
- * * 0: A tensor.
- * * 1: A tensor of the same type, and compatible dimensions as input0.
- * * 2: An INT32 value, and has to be one of the {@link FuseCode} values.
- * Specifies the activation to invoke on the result of each addition.
- *
- * Outputs:
- * * 0: The product, a tensor of the same type as input0.
- */
ANEURALNETWORKS_MUL = 18,
- /** Computes rectified linear activation on the input tensor element-wise.
- *
- * The output is calculated using this formula:
- *
- * output = max(0, input)
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: up to 4.
- *
- * Inputs:
- * * 0: A tensor, specifying the input.
- *
- * Outputs:
- * * 0: The output tensor of same shape as input0.
- */
ANEURALNETWORKS_RELU = 19,
- /** Computes rectified linear 1 activation on the input tensor element-wise.
- *
- * The output is calculated using this formula:
- *
- * output = min(1.f, max(-1.f, input))
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: up to 4.
- *
- * Inputs:
- * * 0: A tensor, specifying the input.
- *
- * Outputs:
- * * 0: The output tensor of same shape as input0.
- */
ANEURALNETWORKS_RELU1 = 20,
- /** Computes rectified linear 6 activation on the input tensor element-wise.
- *
- * The output is calculated using this formula:
- *
- * output = min(6, max(0, input))
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: up to 4.
- *
- * Inputs:
- * * 0: A tensor, specifying the input.
- *
- * Outputs:
- * * 0: The output tensor of same shape as input0.
- */
ANEURALNETWORKS_RELU6 = 21,
- /** Reshapes a tensor.
- *
- * Given tensor, this operation returns a tensor that has the same values as
- * tensor, but with a newly specified shape.
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: up to 4.
- *
- * Inputs:
- * * 0: A tensor, specifying the tensor to be reshaped.
- * * 1: A 1-D tensor of type {@link ANEURALNETWORKS_TENSOR_INT32}, defining
- * the shape of the output tensor. The number of elements implied by shape
- * must be the same as the number of elements in the input tensor.
- *
- * Outputs:
- * * 0: The output tensor, of shape specified by the input shape.
- */
ANEURALNETWORKS_RESHAPE = 22,
- /** Resizes images to given size using bilinear interpolation.
- *
- * Resized images will be distorted if their original aspect ratio is not the
- * same as input.
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- *
- * Supported tensor rank: 4, with "NHWC" data layout.
- *
- * Inputs:
- * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the
- * input.
- * * 1: An INT32 value, specifying the output width of the output tensor.
- * * 2: An INT32 value, specifying the output height of the output tensor.
- *
- * Outputs:
- * * 0: The output 4-D tensor, of shape [batches, new_height, new_width,
- * depth].
- */
ANEURALNETWORKS_RESIZE_BILINEAR = 23,
-
- /**
- * A basic recurrent neural network layer.
- *
- * This layer implements the operation:
- * outputs = state = activation(inputs * input_weights + state *
- * recurrent_weights + bias)
- *
- * Where:
- * * “input_weights” is a weight matrix that multiplies the inputs;
- * * “recurrent_weights” is a weight matrix that multiplies the current
- * “state” which itself is the output from the previous time step
- * computation;
- * * “bias” is a bias vector (added to each output vector in the batch);
- * * “activation” is the function passed as the “fused_activation_function”
- * argument (if not “NONE”).
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- *
- * Inputs:
- * * 0: input.
- * A 2-D tensor of type T, of shape [batch_size, input_size], where
- * “batch_size” corresponds to the batching dimension, and “input_size”
- * is the size of the input.
- * * 1: weights.
- * A 2-D tensor of type T, of shape [num_units, input_size], where
- * “num_units” corresponds to the number of units.
- * * 2: recurrent_weights.
- * A 2-D tensor of type T, of shape [num_units, num_units], with columns
- * corresponding to the weights from each unit.
- * * 3: bias.
- * A 1-D tensor of type T, of shape [num_units].
- *
- * For FLOAT32 input tensor, bias must also be FLOAT32.
- * For UINT8 input tensor, bias must be INT32.
- *
- * Parameters
- * * 4: fused_activation_function.
- * An (optional) ActivationFunctionType indicating the activation
- * function. If “NONE” is specified then it results in a linear
- * activation.
- *
- * * 5: Hidden state.
- * A 2-D tensor of type T, of shape [batch_size, num_units].
- *
- * Outputs:
- * * 0: output.
- * A 2-D tensor of type T, of shape [batch_size, num_units]. This is
- * effectively the same as the current state value.
- */
ANEURALNETWORKS_RNN = 24,
-
- /** Computes the softmax activation on the input tensor element-wise, per
- * batch, by normalizing the input vector so the maximum coefficient is zero.
- *
- * The output is calculated using this formula:
- *
- * output[batch, i] =
- * exp((input[batch, i] - max(input[batch, :])) * beta) /
- * sum_{k}{exp((input[batch, k] - max(input[batch, :])) * beta)}
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: 2 or 4.
- *
- * Inputs:
- * * 0: A 2-D or 4-D tensor, specifying the tensor to be reshaped.
- * * 1: A FLOAT32 value, specifying the scaling factor for the exponent, beta.
- *
- * Outputs:
- * * 0: The output tensor of same shape as input0.
- */
ANEURALNETWORKS_SOFTMAX = 25,
-
- /** Rearranges blocks of spatial data, into depth.
- *
- * More specifically, this op outputs a copy of the input tensor where values
- * from the height and width dimensions are moved to the depth dimension. The
- * value block_size indicates the input block size and how the data is moved.
- *
- * Chunks of data of size block_size * block_size from depth are rearranged
- * into non-overlapping blocks of size block_size x block_size.
- *
- * The depth of the output tensor is input_depth * block_size * block_size.
- * The input tensor's height and width must be divisible by block_size.
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}
- *
- * Supported tensor rank: 4, with "NHWC" data layout.
- *
- * Inputs:
- * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying
- * the input.
- * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and
- * block_size must be a divisor of both the input height and width.
- *
- * Outputs:
- * * 0: The output 4-D tensor, of shape [batch, height/block_size,
- * width/block_size, depth*block_size*block_size].
- */
ANEURALNETWORKS_SPACE_TO_DEPTH = 26,
-
- /**
- * SVDF op is a kind of stateful layer derived from the notion that a
- * densely connected layer that's processing a sequence of input frames can
- * be approximated by using a singular value decomposition of each of its
- * nodes. The implementation is based on:
- *
- * https://research.google.com/pubs/archive/43813.pdf
- *
- * P. Nakkiran, R. Alvarez, R. Prabhavalkar, C. Parada.
- * “Compressing Deep Neural Networks using a Rank-Constrained Topology”.
- * INTERSPEECH, 2015.
- *
- * It processes the incoming input using a 2-stage filtering mechanism:
- * * stage 1 performs filtering on the "features" dimension, whose outputs get
- * pushed into a memory of fixed-size memory_size.
- * * stage 2 performs filtering on the "time" dimension of the memory_size
- * memoized outputs of stage 1.
- *
- * Specifically, for rank 1, this layer implements the operation:
- *
- * memory = push(conv1d(inputs, weights_feature, feature_dim, "VALID"));
- * outputs = activation(memory * weights_time + bias);
- *
- * Where:
- * * “weights_feature” is a weights matrix that processes the inputs (by
- * convolving the input with every “feature filter”), and whose outputs get
- * pushed, stacked in order, into the fixed-size “memory” (the oldest entry
- * gets dropped);
- * * “weights_time” is a weights matrix that processes the “memory” (by a
- * batched matrix multiplication on the num_units);
- * * “bias” is an optional bias vector (added to each output vector in the
- * batch); and
- * * “activation” is the function passed as the “fused_activation_function”
- * argument (if not “NONE”).
- *
- * Each rank adds a dimension to the weights matrices by means of stacking
- * the filters.
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- *
- * Inputs:
- * * 0: input.
- * A 2-D tensor of type T, of shape [batch_size, input_size], where
- * “batch_size” corresponds to the batching dimension, and “input_size”
- * is the size of the input.
- * * 1: weights_feature.
- * A 2-D tensor of type T, of shape [num_units, input_size], where
- * “num_units” corresponds to the number of units.
- * * 2: weights_time.
- * A 2-D tensor of type T, of shape [num_units, memory_size], where
- * “memory_size” corresponds to the fixed-size of the memory.
- * * 3: bias.
- * An optional 1-D tensor of type T, of shape [num_units].
- *
- * For FLOAT32 input tensor, bias must also be FLOAT32.
- * For UINT8 input tensor, bias must be INT32.
- *
- * Parameters:
- * * 4: rank.
- * The rank of the SVD approximation.
- * * 5: fused_activation_function.
- * An (optional) ActivationFunctionType indicating the activation
- * function. If “NONE” is specified then it results in a linear activation.
- *
- * Outputs:
- * * 0: state.
- * A 2-D tensor of type T, of shape [batch_size, (memory_size - 1) *
- * num_units * rank].
- * * 1: output.
- * A 2-D tensor of type T, of shape [batch_size, num_units].
- */
ANEURALNETWORKS_SVDF = 27,
-
- /** Computes hyperbolic tangent of input tensor element-wise.
- *
- * The output is calculated using this formula:
- *
- * output = tanh(input)
- *
- * Supported tensor types:
- * * {@link ANEURALNETWORKS_TENSOR_FLOAT32}
- *
- * Supported tensor rank: up to 4.
- *
- * Inputs:
- * * 0: A tensor, specifying the input.
- *
- * Outputs:
- * * 0: The output tensor of same shape as input0.
- */
ANEURALNETWORKS_TANH = 28,
+ ANEURALNETWORKS_BATCH_TO_SPACE_ND = 29,
+ ANEURALNETWORKS_DIV = 30,
+ ANEURALNETWORKS_MEAN = 31,
+ ANEURALNETWORKS_PAD = 32,
+ ANEURALNETWORKS_SPACE_TO_BATCH_ND = 33,
+ ANEURALNETWORKS_SQUEEZE = 34,
+ ANEURALNETWORKS_STRIDED_SLICE = 35,
+ ANEURALNETWORKS_SUB = 36,
+ ANEURALNETWORKS_TRANSPOSE = 37,
};
/**
@@ -1080,13 +137,9 @@ enum {
*
*/
enum {
- /** NO fused activation function. */
ANEURALNETWORKS_FUSED_NONE = 0,
- /** Fused ReLU activation function. */
ANEURALNETWORKS_FUSED_RELU = 1,
- /** Fused ReLU1 activation function. */
ANEURALNETWORKS_FUSED_RELU1 = 2,
- /** Fused ReLU6 activation function. */
ANEURALNETWORKS_FUSED_RELU6 = 3,
};
@@ -1094,20 +147,8 @@ enum {
* Execution preferences.
*/
enum {
- /**
- * Prefer executing in a way that minimizes battery drain.
- * This is desirable for compilations that will be executed often.
- */
ANEURALNETWORKS_PREFER_LOW_POWER = 0,
- /**
- * Prefer returning a single answer as fast as possible, even if this causes
- * more power consumption.
- */
ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1,
- /**
- * Prefer maximizing the throughput of successive frames, for example when
- * processing successive frames coming from the camera.
- */
ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2,
};
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index eb451397bd..d99c88a26d 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -23,6 +23,10 @@ limitations under the License.
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
+#ifdef __ANDROID__
+#include <sys/system_properties.h>
+#endif
+
namespace tflite {
// TODO(aselle): FATAL leaves resources hanging.
@@ -46,6 +50,32 @@ void FATAL(const char* format, ...) {
FATAL("Aborting since tflite returned failure."); \
}
+namespace {
+
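+// Reads the ro.build.version.sdk property on Android builds; returns 0xFFFF
+// for a non-numeric value and 0 when not compiled for Android.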
+int32_t GetAndroidSdkVersion() {
+#ifdef __ANDROID__
+ const char* sdkProp = "ro.build.version.sdk";
+ char sdkVersion[PROP_VALUE_MAX];
+ int length = __system_property_get(sdkProp, sdkVersion);
+ if (length != 0) {
+ for (int i = 0; i < length; ++i) {
+ int digit = sdkVersion[i] - '0';
+ if (digit < 0 || digit > 9) {
+ // Non-numeric SDK version, assume it's higher than expected.
+ return 0xFFFF;
+ }
+ }
+ return atoi(sdkVersion);
+ }
+ FATAL("No %s prop", sdkProp);
+#endif // __ANDROID__
+ return 0;
+}
+
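+// Queried once at static-initialization time.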
+static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion();
+
+} // namespace
+
NNAPIAllocation::NNAPIAllocation(const char* filename,
ErrorReporter* error_reporter)
: MMAPAllocation(filename, error_reporter) {
@@ -245,6 +275,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
add_scalar_float32(builtin->proj_clip);
};
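+ // Forwards the TFLite keep_dims flag to NNAPI as a scalar int32 operand.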
+ auto add_mean_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteMeanParams*>(data);
+ add_scalar_int32(builtin->keep_dims);
+ };
+
#if 0
auto add_reshape_params = [&](void* data) {
auto builtin = reinterpret_cast<TfLiteReshapeParams*>(data);
@@ -262,8 +297,9 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
augmented_inputs.push_back(next_id++);
};
#endif
-
+ int nnapi_version = 10;
ANeuralNetworksOperationType nn_op_type;
+
switch (builtin) {
case tflite::BuiltinOperator_ADD:
nn_op_type = ANEURALNETWORKS_ADD;
@@ -337,6 +373,23 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
nn_op_type = ANEURALNETWORKS_LSTM;
break;
}
+ case tflite::BuiltinOperator_PAD:
+ nnapi_version = 11; // require NNAPI 1.1
+ nn_op_type = ANEURALNETWORKS_PAD;
+ break;
+ case tflite::BuiltinOperator_MEAN:
+ nnapi_version = 11; // require NNAPI 1.1
+ add_mean_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_MEAN;
+ break;
+ case tflite::BuiltinOperator_DIV:
+ nnapi_version = 11; // require NNAPI 1.1
+ nn_op_type = ANEURALNETWORKS_DIV;
+ break;
+ case tflite::BuiltinOperator_SUB:
+ nnapi_version = 11; // require NNAPI 1.1
+ nn_op_type = ANEURALNETWORKS_SUB;
+ break;
case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
case tflite::BuiltinOperator_LSH_PROJECTION:
case tflite::BuiltinOperator_SVDF:
@@ -350,7 +403,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case tflite::BuiltinOperator_L2_NORMALIZATION:
case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION:
- case tflite::BuiltinOperator_PAD:
case tflite::BuiltinOperator_PADV2:
case tflite::BuiltinOperator_RESIZE_BILINEAR:
case tflite::BuiltinOperator_CALL:
@@ -361,9 +413,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
case tflite::BuiltinOperator_TOPK_V2:
case tflite::BuiltinOperator_TRANSPOSE:
- case tflite::BuiltinOperator_MEAN:
- case tflite::BuiltinOperator_DIV:
- case tflite::BuiltinOperator_SUB:
case tflite::BuiltinOperator_SPLIT:
case tflite::BuiltinOperator_SQUEEZE:
case tflite::BuiltinOperator_STRIDED_SLICE:
@@ -382,6 +431,8 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_LESS_EQUAL:
case tflite::BuiltinOperator_NEG:
case tflite::BuiltinOperator_SELECT:
+ case tflite::BuiltinOperator_SLICE:
+ case tflite::BuiltinOperator_SIN:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
@@ -391,6 +442,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
break;
}
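+ // Ops that need NNAPI 1.1 can only be delegated on Android P (SDK 28) or newer.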
+ if (nnapi_version == 11 && kAndroidSdkVersion < 28) {
+ FATAL("Op %d needs NNAPI1.1", builtin);
+ }
+
// Add the operation.
CHECK_NN(ANeuralNetworksModel_addOperation(
nn_model, nn_op_type, static_cast<uint32_t>(augmented_inputs.size()),
diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/op_resolver.cc
new file mode 100644
index 0000000000..f6e435e982
--- /dev/null
+++ b/tensorflow/contrib/lite/op_resolver.cc
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
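+// Returns the registration for (op, version), or nullptr if it was never added.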
+const TfLiteRegistration* MutableOpResolver::FindOp(tflite::BuiltinOperator op,
+ int version) const {
+ auto it = builtins_.find(std::make_pair(op, version));
+ return it != builtins_.end() ? &it->second : nullptr;
+}
+
+const TfLiteRegistration* MutableOpResolver::FindOp(const char* op,
+ int version) const {
+ auto it = custom_ops_.find(std::make_pair(op, version));
+ return it != custom_ops_.end() ? &it->second : nullptr;
+}
+
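+// Registers the same kernel for every version in [min_version, max_version].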
+void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
+ TfLiteRegistration* registration,
+ int min_version, int max_version) {
+ for (int version = min_version; version <= max_version; ++version) {
+ TfLiteRegistration new_registration = *registration;
+ new_registration.builtin_code = op;
+ new_registration.version = version;
+ auto op_key = std::make_pair(op, version);
+ builtins_[op_key] = new_registration;
+ }
+}
+
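+// Like AddBuiltin, but keyed by name; the stored builtin_code is forced to CUSTOM.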
+void MutableOpResolver::AddCustom(const char* name,
+ TfLiteRegistration* registration,
+ int min_version, int max_version) {
+ for (int version = min_version; version <= max_version; ++version) {
+ TfLiteRegistration new_registration = *registration;
+ new_registration.builtin_code = BuiltinOperator_CUSTOM;
+ new_registration.version = version;
+ auto op_key = std::make_pair(name, version);
+ custom_ops_[op_key] = new_registration;
+ }
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h
new file mode 100644
index 0000000000..38a2706942
--- /dev/null
+++ b/tensorflow/contrib/lite/op_resolver.h
@@ -0,0 +1,94 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Abstract interface that returns TfLiteRegistrations given op codes or custom
+// op names. This is the mechanism by which ops referenced in the flatbuffer
+// model are mapped to executable function pointers (TfLiteRegistrations).
+class OpResolver {
+ public:
+ // Finds the op registration for a builtin operator by enum code.
+ virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const = 0;
+ // Finds the op registration of a custom operator by op name.
+ virtual const TfLiteRegistration* FindOp(const char* op,
+ int version) const = 0;
+ virtual ~OpResolver() {}
+};
+
+// Some versions of gcc don't support partial specialization in class scope,
+// so these are defined at namespace scope instead.
+namespace op_resolver_hasher {
+template <typename V>
+struct ValueHasher {
+ size_t operator()(const V& v) const { return std::hash<V>()(v); }
+};
+
+template <>
+struct ValueHasher<tflite::BuiltinOperator> {
+ size_t operator()(const tflite::BuiltinOperator& v) const {
+ return std::hash<int>()(static_cast<int>(v));
+ }
+};
+
+template <typename T>
+struct OperatorKeyHasher {
+ size_t operator()(const T& x) const {
+ size_t a = ValueHasher<typename T::first_type>()(x.first);
+ size_t b = ValueHasher<typename T::second_type>()(x.second);
+ // Hash combinator used by TensorFlow core.
+ return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4));
+ }
+};
+} // namespace op_resolver_hasher
+
+// An OpResolver that is mutable; also used as the op resolver in gen_op_registration.
+// A typical usage:
+// MutableOpResolver resolver;
+// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
+// InterpreterBuilder(model, resolver)(&interpreter);
+class MutableOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override;
+ const TfLiteRegistration* FindOp(const char* op, int version) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+ void AddCustom(const char* name, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+
+ private:
+ typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
+ typedef std::pair<std::string, int> CustomOperatorKey;
+
+ std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
+ builtins_;
+ std::unordered_map<CustomOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
+ custom_ops_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/op_resolver_test.cc b/tensorflow/contrib/lite/op_resolver_test.cc
new file mode 100644
index 0000000000..10b7e31972
--- /dev/null
+++ b/tensorflow/contrib/lite/op_resolver_test.cc
@@ -0,0 +1,129 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/op_resolver.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+namespace {
+
+// We need some dummy functions to identify the registrations.
+TfLiteStatus DummyInvoke(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteRegistration* GetDummyRegistration() {
+ static TfLiteRegistration registration = {
+ .init = nullptr,
+ .free = nullptr,
+ .prepare = nullptr,
+ .invoke = DummyInvoke,
+ };
+ return &registration;
+}
+
+TEST(MutableOpResolverTest, FindOp) {
+ MutableOpResolver resolver;
+ resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
+
+ const TfLiteRegistration* found_registration =
+ resolver.FindOp(BuiltinOperator_ADD, 1);
+ ASSERT_NE(found_registration, nullptr);
+ EXPECT_TRUE(found_registration->invoke == DummyInvoke);
+ EXPECT_EQ(found_registration->builtin_code, BuiltinOperator_ADD);
+ EXPECT_EQ(found_registration->version, 1);
+}
+
+TEST(MutableOpResolverTest, FindMissingOp) {
+ MutableOpResolver resolver;
+ resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
+
+ const TfLiteRegistration* found_registration =
+ resolver.FindOp(BuiltinOperator_CONV_2D, 1);
+ EXPECT_EQ(found_registration, nullptr);
+}
+
+TEST(MutableOpResolverTest, RegisterOpWithMultipleVersions) {
+ MutableOpResolver resolver;
+  // The kernel supports versions 2 and 3.
+ resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3);
+
+ const TfLiteRegistration* found_registration;
+
+ found_registration = resolver.FindOp(BuiltinOperator_ADD, 2);
+ ASSERT_NE(found_registration, nullptr);
+ EXPECT_TRUE(found_registration->invoke == DummyInvoke);
+ EXPECT_EQ(found_registration->version, 2);
+
+ found_registration = resolver.FindOp(BuiltinOperator_ADD, 3);
+ ASSERT_NE(found_registration, nullptr);
+ EXPECT_TRUE(found_registration->invoke == DummyInvoke);
+ EXPECT_EQ(found_registration->version, 3);
+}
+
+TEST(MutableOpResolverTest, FindOpWithUnsupportedVersions) {
+ MutableOpResolver resolver;
+  // The kernel supports versions 2 and 3.
+ resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2, 3);
+
+ const TfLiteRegistration* found_registration;
+
+ found_registration = resolver.FindOp(BuiltinOperator_ADD, 1);
+ EXPECT_EQ(found_registration, nullptr);
+
+ found_registration = resolver.FindOp(BuiltinOperator_ADD, 4);
+ EXPECT_EQ(found_registration, nullptr);
+}
+
+TEST(MutableOpResolverTest, FindCustomOp) {
+ MutableOpResolver resolver;
+ resolver.AddCustom("AWESOME", GetDummyRegistration());
+
+ const TfLiteRegistration* found_registration = resolver.FindOp("AWESOME", 1);
+ ASSERT_NE(found_registration, nullptr);
+ EXPECT_EQ(found_registration->builtin_code, BuiltinOperator_CUSTOM);
+ EXPECT_TRUE(found_registration->invoke == DummyInvoke);
+ EXPECT_EQ(found_registration->version, 1);
+ // TODO(ycling): The `custom_name` in TfLiteRegistration isn't properly
+ // filled yet. Fix this and add tests.
+}
+
+TEST(MutableOpResolverTest, FindMissingCustomOp) {
+ MutableOpResolver resolver;
+ resolver.AddCustom("AWESOME", GetDummyRegistration());
+
+ const TfLiteRegistration* found_registration =
+ resolver.FindOp("EXCELLENT", 1);
+ EXPECT_EQ(found_registration, nullptr);
+}
+
+TEST(MutableOpResolverTest, FindCustomOpWithUnsupportedVersion) {
+ MutableOpResolver resolver;
+ resolver.AddCustom("AWESOME", GetDummyRegistration());
+
+ const TfLiteRegistration* found_registration = resolver.FindOp("AWESOME", 2);
+ EXPECT_EQ(found_registration, nullptr);
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 2f5c39e7d7..68af29d451 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -142,6 +142,8 @@ enum BuiltinOperator : byte {
GREATER_EQUAL = 62,
LESS_EQUAL = 63,
SELECT = 64,
+ SLICE = 65,
+ SIN = 66,
}
// Options for the builtin operators.
@@ -193,6 +195,7 @@ union BuiltinOptions {
GreaterEqualOptions,
LessEqualOptions,
SelectOptions,
+ SliceOptions,
}
enum Padding : byte { SAME, VALID }
@@ -436,11 +439,18 @@ table NegOptions {
table SelectOptions {
}
+table SliceOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
builtin_code:BuiltinOperator;
custom_code:string;
+
+  // The version of the operator. The version needs to be bumped whenever new
+ // parameters are introduced into an op.
+ version:int = 1;
}
enum CustomOptionsFormat : byte {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index a2f0c8cdd2..3f6bbf0566 100644..100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -172,6 +172,9 @@ struct NegOptionsT;
struct SelectOptions;
struct SelectOptionsT;
+struct SliceOptions;
+struct SliceOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -296,11 +299,13 @@ enum BuiltinOperator {
BuiltinOperator_GREATER_EQUAL = 62,
BuiltinOperator_LESS_EQUAL = 63,
BuiltinOperator_SELECT = 64,
+ BuiltinOperator_SLICE = 65,
+ BuiltinOperator_SIN = 66,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_SELECT
+ BuiltinOperator_MAX = BuiltinOperator_SIN
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[64] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[66] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -365,7 +370,9 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[64] {
BuiltinOperator_GREATER,
BuiltinOperator_GREATER_EQUAL,
BuiltinOperator_LESS_EQUAL,
- BuiltinOperator_SELECT
+ BuiltinOperator_SELECT,
+ BuiltinOperator_SLICE,
+ BuiltinOperator_SIN
};
return values;
}
@@ -437,6 +444,8 @@ inline const char **EnumNamesBuiltinOperator() {
"GREATER_EQUAL",
"LESS_EQUAL",
"SELECT",
+ "SLICE",
+ "SIN",
nullptr
};
return names;
@@ -496,11 +505,12 @@ enum BuiltinOptions {
BuiltinOptions_GreaterEqualOptions = 45,
BuiltinOptions_LessEqualOptions = 46,
BuiltinOptions_SelectOptions = 47,
+ BuiltinOptions_SliceOptions = 48,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_SelectOptions
+ BuiltinOptions_MAX = BuiltinOptions_SliceOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[48] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[49] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -549,7 +559,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[48] {
BuiltinOptions_GreaterOptions,
BuiltinOptions_GreaterEqualOptions,
BuiltinOptions_LessEqualOptions,
- BuiltinOptions_SelectOptions
+ BuiltinOptions_SelectOptions,
+ BuiltinOptions_SliceOptions
};
return values;
}
@@ -604,6 +615,7 @@ inline const char **EnumNamesBuiltinOptions() {
"GreaterEqualOptions",
"LessEqualOptions",
"SelectOptions",
+ "SliceOptions",
nullptr
};
return names;
@@ -806,6 +818,10 @@ template<> struct BuiltinOptionsTraits<SelectOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_SelectOptions;
};
+template<> struct BuiltinOptionsTraits<SliceOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_SliceOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1213,6 +1229,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_SelectOptions ?
reinterpret_cast<const SelectOptionsT *>(value) : nullptr;
}
+ SliceOptionsT *AsSliceOptions() {
+ return type == BuiltinOptions_SliceOptions ?
+ reinterpret_cast<SliceOptionsT *>(value) : nullptr;
+ }
+ const SliceOptionsT *AsSliceOptions() const {
+ return type == BuiltinOptions_SliceOptions ?
+ reinterpret_cast<const SliceOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -4380,12 +4404,54 @@ inline flatbuffers::Offset<SelectOptions> CreateSelectOptions(
flatbuffers::Offset<SelectOptions> CreateSelectOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct SliceOptionsT : public flatbuffers::NativeTable {
+ typedef SliceOptions TableType;
+ SliceOptionsT() {
+ }
+};
+
+struct SliceOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SliceOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ SliceOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(SliceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<SliceOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct SliceOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit SliceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SliceOptionsBuilder &operator=(const SliceOptionsBuilder &);
+ flatbuffers::Offset<SliceOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SliceOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SliceOptions> CreateSliceOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ SliceOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<SliceOptions> CreateSliceOptions(flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
std::string custom_code;
+ int32_t version;
OperatorCodeT()
- : builtin_code(BuiltinOperator_ADD) {
+ : builtin_code(BuiltinOperator_ADD),
+ version(1) {
}
};
@@ -4393,7 +4459,8 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef OperatorCodeT NativeTableType;
enum {
VT_BUILTIN_CODE = 4,
- VT_CUSTOM_CODE = 6
+ VT_CUSTOM_CODE = 6,
+ VT_VERSION = 8
};
BuiltinOperator builtin_code() const {
return static_cast<BuiltinOperator>(GetField<int8_t>(VT_BUILTIN_CODE, 0));
@@ -4401,11 +4468,15 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const flatbuffers::String *custom_code() const {
return GetPointer<const flatbuffers::String *>(VT_CUSTOM_CODE);
}
+ int32_t version() const {
+ return GetField<int32_t>(VT_VERSION, 1);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_BUILTIN_CODE) &&
VerifyOffset(verifier, VT_CUSTOM_CODE) &&
verifier.Verify(custom_code()) &&
+ VerifyField<int32_t>(verifier, VT_VERSION) &&
verifier.EndTable();
}
OperatorCodeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -4422,6 +4493,9 @@ struct OperatorCodeBuilder {
void add_custom_code(flatbuffers::Offset<flatbuffers::String> custom_code) {
fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code);
}
+ void add_version(int32_t version) {
+ fbb_.AddElement<int32_t>(OperatorCode::VT_VERSION, version, 1);
+ }
explicit OperatorCodeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -4437,8 +4511,10 @@ struct OperatorCodeBuilder {
inline flatbuffers::Offset<OperatorCode> CreateOperatorCode(
flatbuffers::FlatBufferBuilder &_fbb,
BuiltinOperator builtin_code = BuiltinOperator_ADD,
- flatbuffers::Offset<flatbuffers::String> custom_code = 0) {
+ flatbuffers::Offset<flatbuffers::String> custom_code = 0,
+ int32_t version = 1) {
OperatorCodeBuilder builder_(_fbb);
+ builder_.add_version(version);
builder_.add_custom_code(custom_code);
builder_.add_builtin_code(builtin_code);
return builder_.Finish();
@@ -4447,11 +4523,13 @@ inline flatbuffers::Offset<OperatorCode> CreateOperatorCode(
inline flatbuffers::Offset<OperatorCode> CreateOperatorCodeDirect(
flatbuffers::FlatBufferBuilder &_fbb,
BuiltinOperator builtin_code = BuiltinOperator_ADD,
- const char *custom_code = nullptr) {
+ const char *custom_code = nullptr,
+ int32_t version = 1) {
return tflite::CreateOperatorCode(
_fbb,
builtin_code,
- custom_code ? _fbb.CreateString(custom_code) : 0);
+ custom_code ? _fbb.CreateString(custom_code) : 0,
+ version);
}
flatbuffers::Offset<OperatorCode> CreateOperatorCode(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -4638,6 +4716,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const SelectOptions *builtin_options_as_SelectOptions() const {
return builtin_options_type() == BuiltinOptions_SelectOptions ? static_cast<const SelectOptions *>(builtin_options()) : nullptr;
}
+ const SliceOptions *builtin_options_as_SliceOptions() const {
+ return builtin_options_type() == BuiltinOptions_SliceOptions ? static_cast<const SliceOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -4852,6 +4933,10 @@ template<> inline const SelectOptions *Operator::builtin_options_as<SelectOption
return builtin_options_as_SelectOptions();
}
+template<> inline const SliceOptions *Operator::builtin_options_as<SliceOptions>() const {
+ return builtin_options_as_SliceOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -6616,6 +6701,29 @@ inline flatbuffers::Offset<SelectOptions> CreateSelectOptions(flatbuffers::FlatB
_fbb);
}
+inline SliceOptionsT *SliceOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new SliceOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void SliceOptions::UnPackTo(SliceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<SliceOptions> SliceOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateSliceOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<SliceOptions> CreateSliceOptions(flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SliceOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateSliceOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -6627,6 +6735,7 @@ inline void OperatorCode::UnPackTo(OperatorCodeT *_o, const flatbuffers::resolve
(void)_resolver;
{ auto _e = builtin_code(); _o->builtin_code = _e; };
{ auto _e = custom_code(); if (_e) _o->custom_code = _e->str(); };
+ { auto _e = version(); _o->version = _e; };
}
inline flatbuffers::Offset<OperatorCode> OperatorCode::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -6639,10 +6748,12 @@ inline flatbuffers::Offset<OperatorCode> CreateOperatorCode(flatbuffers::FlatBuf
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OperatorCodeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _builtin_code = _o->builtin_code;
auto _custom_code = _o->custom_code.empty() ? 0 : _fbb.CreateString(_o->custom_code);
+ auto _version = _o->version;
return tflite::CreateOperatorCode(
_fbb,
_builtin_code,
- _custom_code);
+ _custom_code,
+ _version);
}
inline OperatorT *Operator::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -6987,6 +7098,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const SelectOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_SliceOptions: {
+ auto ptr = reinterpret_cast<const SliceOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -7193,6 +7308,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const SelectOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_SliceOptions: {
+ auto ptr = reinterpret_cast<const SliceOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -7387,6 +7506,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const SelectOptionsT *>(value);
return CreateSelectOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_SliceOptions: {
+ auto ptr = reinterpret_cast<const SliceOptionsT *>(value);
+ return CreateSliceOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -7581,6 +7704,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new SelectOptionsT(*reinterpret_cast<SelectOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_SliceOptions: {
+ value = new SliceOptionsT(*reinterpret_cast<SliceOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -7823,6 +7950,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_SliceOptions: {
+ auto ptr = reinterpret_cast<SliceOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index f89c0d28d3..34f1f1b6b0 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -55,6 +55,8 @@ gen_zipped_test_files(
"reshape.zip",
"resize_bilinear.zip",
"sigmoid.zip",
+ "sin.zip",
+ "slice.zip",
"softmax.zip",
"space_to_batch_nd.zip",
"space_to_depth.zip",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index f7cc7da900..da57247a37 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -20,6 +20,9 @@ Usage:
generate_examples <output directory>
bazel run //tensorflow/contrib/lite/testing:generate_examples
+
+To more easily debug failures, use (or override) the --save_graphdefs flag to
+place text proto graphdefs into the generated zip files.
"""
from __future__ import absolute_import
from __future__ import division
@@ -90,15 +93,12 @@ KNOWN_BUGS = {
r"fully_connected.*transpose_.=True": "67586970",
# Softmax graphs are too complex.
r"softmax.*dim=0": "67749831",
- r"softmax.*input_shape=\[1,3,4,3\]": "67749831",
# SpaceToDepth only supports float32.
r"space_to_depth.*(float16|int32|uint8|int64)": "68018134",
# BatchToSpaceND only supports 4D tensors.
r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733",
# Div will use floordiv.
r"div.*int32": "72051395",
- # TOCO require matching dimensions in strided_slice.
- r"strided_slice.*begin=\[0\].*end=\[1\].*": "73170889",
# No support for SplitV
r"split.*num_or_size_splits=\[2,2\]": "73377559",
# Needs support for dimensions other than the last one in argmax.
@@ -430,7 +430,7 @@ def make_zip_of_tests(zip_path,
report["toco_log"] = toco_log
if FLAGS.save_graphdefs:
- archive.writestr(label + ".pb",
+ archive.writestr(label + ".pbtxt",
text_format.MessageToString(graph_def),
zipfile.ZIP_DEFLATED)
@@ -1812,7 +1812,19 @@ def make_strided_slice_tests(zip_path):
"shrink_axis_mask": [None, 1, 8, 11, 15, -1],
"constant_indices": [False, True],
},
- # TODO(b/73170889) Restore test parameters removed in cl/191608113.
+      # Begin, end, and strides dims differ from the input shape.
+ {
+ "dtype": [tf.float32],
+ "index_type": [tf.int32],
+ "input_shape": [[12, 2, 2, 5]],
+ "begin": [[0]],
+ "end": [[1]],
+ "strides": [None, [1]],
+ "begin_mask": [0],
+ "end_mask": [0],
+ "shrink_axis_mask": [1],
+ "constant_indices": [True],
+ },
# 2-D
{
"dtype": [tf.float32, tf.int32, tf.int64],
@@ -2242,6 +2254,32 @@ def make_neg_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_sin_tests(zip_path):
+ """Make a set of tests to do sin."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32],
+ "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ """Build the sin op testing graph."""
+ input_value = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape"])
+ out = tf.sin(input_value)
+ return [input_value], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict={inputs[0]: input_value})
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_where_tests(zip_path):
"""Make a set of tests to do where."""
@@ -2274,6 +2312,62 @@ def make_where_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+def make_slice_tests(zip_path):
+ """Make a set of tests to do slice."""
+
+ # TODO(renjieliu): add test/support for uint8.
+ test_parameters = [
+ # 4-D
+ {
+ "dtype": [tf.float32, tf.int32, tf.int64],
+ "index_type": [tf.int32, tf.int64],
+ "input_shape": [[12, 2, 2, 5]],
+ "begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
+ "size": [[8, 2, 2, 3], [11, 2, 1, 5]],
+ },
+ # 2-D
+ {
+ "dtype": [tf.float32, tf.int32, tf.int64],
+ "index_type": [tf.int32, tf.int64],
+ "input_shape": [[2, 3]],
+ "begin": [[0, 0], [1, 0]],
+ "size": [[2, 3], [2, 2]],
+ },
+ ]
+
+ def build_graph(parameters):
+ """Build graph for slice test."""
+ input_tensor = tf.placeholder(
+ dtype=parameters["dtype"],
+ name="input",
+ shape=parameters["input_shape"])
+ begin = tf.placeholder(
+ dtype=parameters["index_type"],
+ name="begin",
+ shape=[len(parameters["input_shape"])])
+ size = tf.placeholder(
+ dtype=parameters["index_type"],
+ name="size",
+ shape=[len(parameters["input_shape"])])
+ tensors = [input_tensor, begin, size]
+ out = tf.slice(input_tensor, begin, size)
+ return tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ """Build inputs for slice test."""
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ index_type = _TF_TYPE_INFO[parameters["index_type"]][0]
+
+ begin_values = np.array(parameters["begin"]).astype(index_type)
+ size_values = np.array(parameters["size"]).astype(index_type)
+ values = [input_values, begin_values, size_values]
+
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
# Toco binary path provided by the generate rule.
bin_path = None
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 49762bdfe7..6ecaf2a355 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -63,10 +63,6 @@ std::map<string, string> kBrokenTests = {
// L2Norm only supports tensors with 4D or fewer.
{R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
- // BatchToSpaceND doesn't support cropping. This catches test cases with
- // non-const tensors as crops.
- {R"(^\/batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\])", "70594634"},
-
// SpaceToBatchND only supports 4D tensors.
{R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"},
@@ -204,7 +200,7 @@ std::vector<string> UnarchiveZipAndFindTestNames(const string& zip_file_name) {
class OpsTest : public ::testing::TestWithParam<string> {};
-TEST_P(OpsTest, RunStuff) {
+TEST_P(OpsTest, RunZipTests) {
string test_path = GetParam();
string tflite_test_case = test_path + "_tests.txt";
string tflite_dir = test_path.substr(0, test_path.find_last_of("/"));
@@ -227,7 +223,9 @@ TEST_P(OpsTest, RunStuff) {
EXPECT_TRUE(result) << test_driver.GetErrorMessage();
} else {
if (FLAGS_ignore_known_bugs) {
- EXPECT_FALSE(result);
+ EXPECT_FALSE(result) << "Test was expected to fail but is now passing; "
+ "you can mark http://b/"
+ << bug_number << " as fixed! Yay!";
} else {
EXPECT_TRUE(result) << test_driver.GetErrorMessage()
<< ": Possibly due to http://b/" << bug_number;
@@ -235,12 +233,29 @@ TEST_P(OpsTest, RunStuff) {
}
}
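+// Converts a zip entry path into a valid gtest parameter name: strips the
+// directory prefix and replaces any character that is not alphanumeric or '_'
+// with '_'.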
+struct ZipPathParamName {
+ template <class ParamType>
+ string operator()(const ::testing::TestParamInfo<ParamType>& info) const {
+ string param_name = info.param;
+ size_t last_slash = param_name.find_last_of("\\/");
+ if (last_slash != string::npos) {
+ param_name = param_name.substr(last_slash);
+ }
+ for (size_t index = 0; index < param_name.size(); ++index) {
+ if (!isalnum(param_name[index]) && param_name[index] != '_')
+ param_name[index] = '_';
+ }
+ return param_name;
+ }
+};
+
// Instantiate a test. This assumes `zip_base`.zip is a declared data file
// of this test.
-#define INSTANTIATE_TESTS(zip_base) \
- INSTANTIATE_TEST_CASE_P( \
- zip_base, OpsTest, \
- ::testing::ValuesIn(UnarchiveZipAndFindTestNames(#zip_base ".zip")));
+#define INSTANTIATE_TESTS(zip_base) \
+ INSTANTIATE_TEST_CASE_P( \
+ zip_base, OpsTest, \
+ ::testing::ValuesIn(UnarchiveZipAndFindTestNames(#zip_base ".zip")), \
+ ZipPathParamName());
INSTANTIATE_TESTS(add)
INSTANTIATE_TESTS(arg_max)
@@ -281,6 +296,8 @@ INSTANTIATE_TESTS(relu6)
INSTANTIATE_TESTS(reshape)
INSTANTIATE_TESTS(resize_bilinear)
INSTANTIATE_TESTS(sigmoid)
+INSTANTIATE_TESTS(sin)
+INSTANTIATE_TESTS(slice)
INSTANTIATE_TESTS(softmax)
INSTANTIATE_TESTS(space_to_batch_nd)
INSTANTIATE_TESTS(space_to_depth)
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 01ce0d9db2..b8acc9a8e0 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -273,6 +273,7 @@ cc_library(
"graph_transformations/resolve_constant_range.cc",
"graph_transformations/resolve_constant_reshape.cc",
"graph_transformations/resolve_constant_shape_or_rank.cc",
+ "graph_transformations/resolve_constant_slice.cc",
"graph_transformations/resolve_constant_stack.cc",
"graph_transformations/resolve_constant_strided_slice.cc",
"graph_transformations/resolve_constant_transpose.cc",
diff --git a/tensorflow/contrib/lite/toco/format_port.h b/tensorflow/contrib/lite/toco/format_port.h
index eb81e90faf..44e6684571 100644
--- a/tensorflow/contrib/lite/toco/format_port.h
+++ b/tensorflow/contrib/lite/toco/format_port.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// This file is used to provide equivalents of internal util::format::FormatF
-// and util::format::AppendF. Unfortunately, type safety is not as good as a
+// This file is used to provide equivalents of internal absl::FormatF
+// and absl::StrAppendFormat. Unfortunately, type safety is not as good as
// a full C++ example.
// TODO(aselle): When absl adds support for StrFormat, use that instead.
#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 4e3ea72182..8da242aa9c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -182,6 +182,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 437e30a918..d63ee7c951 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -188,6 +188,32 @@ bool HardcodeMinMaxFromFirstInput(Model* model, Operator* op) {
return true;
}
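+// Propagates minmax to a Select output: both value inputs (1 and 2) must
+// already carry identical minmax, which is then copied to the output.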
+bool HardcodeMinMaxForSelect(Model* model, Operator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array_1 = model->GetArray(op->inputs[1]);
+ if (!input_array_1.minmax) {
+ return false;
+ }
+ const auto& input_array_2 = model->GetArray(op->inputs[2]);
+ if (!input_array_2.minmax) {
+ return false;
+ }
+
+ const auto& input_minmax_1 = input_array_1.GetMinMax();
+ const auto& input_minmax_2 = input_array_2.GetMinMax();
+
+ CHECK_EQ(input_minmax_1.min, input_minmax_2.min);
+ CHECK_EQ(input_minmax_1.max, input_minmax_2.max);
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = input_minmax_1.min;
+ output_minmax.max = input_minmax_1.max;
+ return true;
+}
+
bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min,
double max) {
CHECK_EQ(op->outputs.size(), 1);
@@ -345,7 +371,9 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
case OperatorType::kMean:
changed = HardcodeMinMaxFromFirstInput(model, op);
break;
-
+ case OperatorType::kSelect:
+ changed = HardcodeMinMaxForSelect(model, op);
+ break;
case OperatorType::kLogistic:
// We hardcode quantization_params to: zero_point=0, scale=1/256.
// This choice of minmax is the one that is equivalent to that.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
index 0bce183c18..6d51fc8c31 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -102,6 +102,7 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) {
// Gathers need their parameters changed to the appropriate data type.
case OperatorType::kTensorFlowReshape:
case OperatorType::kTranspose:
+ case OperatorType::kSelect:
// Reshapes and transposes don't change values.
return false;
default:
@@ -113,6 +114,8 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) {
// propagation.
bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) {
switch (op.type) {
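+    // The condition input (index 0) blocks propagation; the value inputs don't.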
+ case OperatorType::kSelect:
+ return input_index == 0;
case OperatorType::kGather:
// Ignore gather indices.
return input_index != 0;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 52b739c5e2..9d1d27f3ef 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1514,6 +1514,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kCast:
case OperatorType::kFloor:
case OperatorType::kExp:
+ case OperatorType::kSin:
ProcessSimpleOperator(model, op, 0);
break;
case OperatorType::kGather:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index a1ca7371c8..142841fcc4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -59,7 +59,8 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kTensorFlowGreater ||
type == OperatorType::kTensorFlowGreaterEqual ||
type == OperatorType::kTensorFlowLess ||
- type == OperatorType::kTensorFlowLessEqual;
+ type == OperatorType::kTensorFlowLessEqual ||
+ type == OperatorType::kSelect;
}
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
index 3e021b819f..a950fe6442 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -85,9 +85,11 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
"Removing %s, keeping its non-constant input array %s and removing %s",
LogName(*passthru_op), main_input_name, output_name);
RerouteEdges(output_name, main_input_name, model);
- } else if (IsDiscardableArray(*model, main_input_name)) {
+ } else if (IsDiscardableArray(*model, main_input_name) &&
+ !IsConstantParameterArray(*model, main_input_name)) {
transformation->AddMessageF(
- "Removing %s, keeping its output array %s and removing input %s",
+ "Removing %s, keeping its output array %s and removing non-constant "
+ "input %s",
LogName(*passthru_op), output_name, main_input_name);
RerouteEdges(main_input_name, output_name, model);
} else {
@@ -95,10 +97,23 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
"Cannot remove %s, neither its main input nor its output may be "
"discarded",
LogName(*passthru_op));
- return false;
+ if (passthru_op->type != OperatorType::kTensorFlowReshape &&
+ model->GetArray(main_input_name).has_shape()) {
+ // We can't remove either array but we can remove the op. Converting it to
+ // a reshape gives us some hope of later on fixing that (either in the
+ // final runtime or as an additional fixup step).
+ //
+ // Note that we don't try to insert copies in place of reshapes as the
+ // copy itself is a trivial reshape and we'd go into an infinite loop!
+ transformation->AddMessageF("Replacing with a copy (reshape) instead");
+ InsertCopyOperator(model, main_input_name, output_name);
+ } else {
+ return false;
+ }
}
// Remove the pass-through node.
+ CHECK_EQ(passthru_it->get(), passthru_op);
model->operators.erase(passthru_it);
// Remove any array that is no longer used.
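As a reading aid for the fallback added above, here is a hedged Python sketch of the overall decision in RemoveTrivialPassthroughOp: reroute onto a discardable output first, then onto a discardable non-constant input, and only when neither is possible replace the op with a trivial copy (reshape), provided the op is not itself a reshape and the input shape is known. All helper names are illustrative.

def remove_trivial_passthrough(model, op):
    main_in, out = op.main_input, op.output
    if model.is_discardable(out):
        model.reroute_edges(out, main_in)        # keep the input array
    elif model.is_discardable(main_in) and not model.is_constant(main_in):
        model.reroute_edges(main_in, out)        # keep the output array
    elif op.type != "Reshape" and model.has_shape(main_in):
        # Neither array can be discarded; keep both and replace the op with a
        # trivial copy (itself a reshape) so later passes can still simplify it.
        model.insert_copy(main_in, out)
    else:
        return False
    model.erase_op(op)
    return True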
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
new file mode 100644
index 0000000000..b35c3e19c4
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
@@ -0,0 +1,165 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+template <ArrayDataType Type>
+bool Slice(SliceOperator const& op, Array const& input_array,
+ Array* output_array) {
+ // Implementation is taken from the tflite kernel.
+
+ CHECK(input_array.data_type == Type);
+ CHECK(output_array->data_type == Type);
+ const auto& input_data = input_array.GetBuffer<Type>().data;
+
+ // Create a buffer for the output array.
+ std::vector<DataType<Type>>& output_data =
+ output_array->GetMutableBuffer<Type>().data;
+ output_data.resize(RequiredBufferSizeForShape(output_array->shape()));
+
+ std::vector<int> size = op.size;
+ if (size.size() != op.begin.size()) {
+ // Broadcast the end positions.
+ CHECK_EQ(op.size.size(), 1);
+ int broadcast_size = size[0];
+ while (size.size() < op.begin.size()) size.push_back(broadcast_size);
+ }
+
+ // Calculate begin and end indices along each dimension.
+ CHECK_LE(op.begin.size(), 4);
+ CHECK_LE(size.size(), 4);
+ std::vector<int> begin = op.begin;
+ std::vector<int> end;
+ for (int i = 0; i < begin.size(); ++i) {
+ int dim_size = size[i];
+ if (dim_size == -1) {
+ // -1 means the rest of the dimension.
+ dim_size = input_array.shape().dims()[i] - begin[i];
+ }
+ CHECK_GE(dim_size, 1);
+ end.push_back(begin[i] + dim_size - 1);
+ }
+
+ // Pad out so that we always have 4 dims, makes this loop easier.
+ while (begin.size() < 4) begin.insert(begin.begin(), 0);
+ while (end.size() < 4) end.insert(end.begin(), 0);
+ Shape padded_shape = input_array.shape();
+ while (padded_shape.dimensions_count() < 4) {
+ padded_shape.mutable_dims()->insert(padded_shape.mutable_dims()->begin(),
+ 1);
+ }
+
+ auto* out_ptr = output_data.data();
+ for (int in_b = begin[0]; in_b <= end[0]; ++in_b) {
+ for (int in_h = begin[1]; in_h <= end[1]; ++in_h) {
+ for (int in_w = begin[2]; in_w <= end[2]; ++in_w) {
+ for (int in_d = begin[3]; in_d <= end[3]; ++in_d) {
+ *out_ptr++ =
+ input_data[Offset(padded_shape, {in_b, in_h, in_w, in_d})];
+ }
+ }
+ }
+ }
+
+ return true;
+}
+
+} // namespace
+
+bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ const auto* base_op = it->get();
+ if (base_op->type != OperatorType::kSlice) {
+ return false;
+ }
+
+ const SliceOperator* op = static_cast<const SliceOperator*>(base_op);
+
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes.
+ return false;
+ }
+
+ if (!output_array.has_shape()) {
+ // Yield until the output shape has been set by PropagateFixedShapes.
+ return false;
+ }
+
+ if (op->begin.empty() || op->size.empty()) {
+ // Attributes have not been resolved yet.
+ return false;
+ }
+
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.has_shape()) {
+ // Yield until the value shape has been resolved.
+ return false;
+ }
+ if (!IsConstantParameterArray(*model, op->inputs[0])) {
+ // Yield until the value is constant.
+ return false;
+ }
+
+ CHECK(!output_array.buffer);
+ switch (output_array.data_type) {
+ case ArrayDataType::kFloat:
+ if (!Slice<ArrayDataType::kFloat>(*op, input_array, &output_array)) {
+ return false;
+ }
+ break;
+ case ArrayDataType::kUint8:
+ if (!Slice<ArrayDataType::kUint8>(*op, input_array, &output_array)) {
+ return false;
+ }
+ break;
+ case ArrayDataType::kInt32:
+ if (!Slice<ArrayDataType::kInt32>(*op, input_array, &output_array)) {
+ return false;
+ }
+ break;
+ case ArrayDataType::kInt64:
+ if (!Slice<ArrayDataType::kInt64>(*op, input_array, &output_array)) {
+ return false;
+ }
+ break;
+ default:
+ LOG(FATAL) << "Unsupported data type input to Slice op with output \""
+ << op->outputs[0] << "\"";
+ break;
+ }
+
+ // Erase input array if no longer used.
+ if (IsDiscardableArray(*model, op->inputs[0]) &&
+ CountOpsWithInput(*model, op->inputs[0]) == 1) {
+ model->EraseArray(op->inputs[0]);
+ }
+
+ // Erase the operator
+ model->operators.erase(it);
+
+ return true;
+}
+
+} // namespace toco
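For intuition about the slice semantics implemented above, here is a rough NumPy equivalent, for illustration only: a single size value is broadcast across all begin positions, and a size of -1 means "the rest of that dimension" (the kernel additionally pads everything to 4-D, which this sketch omits).

import numpy as np

def constant_slice(values, begin, size):
    values = np.asarray(values)
    if len(size) == 1 and len(begin) > 1:
        size = size * len(begin)             # broadcast a single size value
    indexers = []
    for axis, (b, s) in enumerate(zip(begin, size)):
        if s == -1:                          # -1 means the rest of this axis
            s = values.shape[axis] - b
        indexers.append(slice(b, b + s))
    return values[tuple(indexers)]

# constant_slice(np.arange(12).reshape(3, 4), begin=[1, 0], size=[-1, 2])
# keeps rows 1..2 and columns 0..1.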
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
index 021e9918f2..65132d7d1e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
@@ -19,6 +19,24 @@ limitations under the License.
namespace toco {
+int PadAttributeArray(Array* attribute_array, std::vector<int> pad_values,
+ int mask) {
+ int attribute_dim_count = attribute_array->shape().dims(0);
+ int dim_count = pad_values.size();
+ if (attribute_dim_count < dim_count) {
+ Shape strided_slice_shape = Shape({dim_count});
+ attribute_array->copy_shape(strided_slice_shape);
+ Buffer<ArrayDataType::kInt32>* buffer =
+ &(attribute_array->GetMutableBuffer<ArrayDataType::kInt32>());
+ buffer->data.resize(RequiredBufferSizeForShape(strided_slice_shape));
+ for (int i = attribute_dim_count; i < dim_count; i++) {
+ buffer->data[i] = pad_values[i];
+ mask |= 1 << i;
+ }
+ }
+ return mask;
+}
+
bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
const auto slice_it = model->operators.begin() + op_index;
auto* slice_op = slice_it->get();
@@ -37,52 +55,63 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
return false;
}
- const auto& start_array = model->GetArray(op->inputs[1]);
+ auto& start_array = model->GetArray(op->inputs[1]);
if (!start_array.has_shape()) return false;
if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) {
// Only 1-4D arrays are supported for now.
return false;
}
- const auto& stop_array = model->GetArray(op->inputs[2]);
+ auto& stop_array = model->GetArray(op->inputs[2]);
if (!stop_array.has_shape()) return false;
- const auto& stride_array = model->GetArray(op->inputs[3]);
+ auto& stride_array = model->GetArray(op->inputs[3]);
if (!stride_array.has_shape()) return false;
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
- op->start_indices = start_array.GetBuffer<ArrayDataType::kInt32>().data;
- op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data;
- op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data;
+ int num_input_axes = input_array.shape().dimensions_count();
+ int start_indices_size = start_array.shape().dims(0);
+ int stop_indices_size = stop_array.shape().dims(0);
+ int stride_indices_size = stride_array.shape().dims(0);
- CHECK_GE(op->start_indices.size(), 1);
- CHECK_LE(op->start_indices.size(), 4);
- CHECK_EQ(op->stop_indices.size(), op->start_indices.size());
- CHECK_EQ(op->strides.size(), op->stop_indices.size());
+ CHECK_GE(start_indices_size, 1);
+ CHECK_LE(start_indices_size, 4);
+ CHECK_LE(stop_indices_size, 4);
+ CHECK_LE(stride_indices_size, 4);
// The TensorFlow documentation is not explicit on how it handles fewer
// supplied indices than dimensions, but they are accepted. We emulate TF's
// behavior by fully iterating over each omitted dimension.
- int num_input_axes = input_array.shape().dimensions_count();
- CHECK_LE(op->start_indices.size(), num_input_axes)
+ CHECK_LE(start_indices_size, num_input_axes)
<< "StridedSlice op requires no more than " << num_input_axes
<< " start indices";
- CHECK_LE(op->stop_indices.size(), num_input_axes)
+ CHECK_LE(stop_indices_size, num_input_axes)
<< "StridedSlice op requires no more than " << num_input_axes
<< " stop indices";
- CHECK_LE(op->strides.size(), num_input_axes)
+ CHECK_LE(stride_indices_size, num_input_axes)
<< "StridedSlice op requires no more than " << num_input_axes
<< " strides";
- op->PadIndices(num_input_axes);
// Ideally, we would remove the input arrays after they have been resolved.
// However, we must then reconstitute these input arrays for all supported
// export formats. For now, leave the arrays so we don't have to modify our
// exporters. Ideally, we wouldn't have op attributes, and would work directly
// with the input arrays.
+ std::vector<int> begin_pad_values(num_input_axes, 0);
+ op->begin_mask =
+ PadAttributeArray(&start_array, begin_pad_values, op->begin_mask);
+ op->end_mask =
+ PadAttributeArray(&stop_array, input_array.shape().dims(), op->end_mask);
+ std::vector<int> stride_pad_values(num_input_axes, 1);
+ PadAttributeArray(&stride_array, stride_pad_values, 0);
+
+ op->start_indices = start_array.GetBuffer<ArrayDataType::kInt32>().data;
+ op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data;
+ op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data;
+
return true;
}
} // namespace toco
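A small plain-Python sketch of what PadAttributeArray above accomplishes may help: the begin/stop/stride attribute arrays are extended up to the input rank, the new entries are filled from the supplied pad values, and the corresponding mask bits are turned on so the padded axes behave as if the index had been omitted. The helper below is illustrative only.

def pad_attribute(values, pad_values, mask):
    # Extend `values` to len(pad_values); set a mask bit for each padded axis.
    for i in range(len(values), len(pad_values)):
        values.append(pad_values[i])
        mask |= 1 << i
    return mask

begin, begin_mask = [1], 0
begin_mask = pad_attribute(begin, [0, 0, 0], begin_mask)
# begin == [1, 0, 0]; begin_mask now has bits 1 and 2 set, so the padded axes
# start from the beginning of their dimension, emulating TF's handling of
# omitted indices.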
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 1eef173afe..af84c667a7 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -189,6 +189,7 @@ Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()),
0.f);
+ CHECK_GE(output_float_data.size(), input_flat_size);
if (input_tensor.float_val_size() == 1) {
for (int i = 0; i < input_flat_size; i++) {
output_float_data[i] = input_tensor.float_val(0);
@@ -221,6 +222,7 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
+ CHECK_GE(output_int_data.size(), input_flat_size);
if (input_tensor.int_val_size()) {
for (int i = 0; i < input_tensor.int_val_size(); i++) {
output_int_data[i] = input_tensor.int_val(i);
@@ -249,6 +251,7 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
+ CHECK_GE(output_int_data.size(), input_flat_size);
if (input_tensor.int_val_size()) {
for (int i = 0; i < input_tensor.int_val_size(); i++) {
output_int_data[i] = input_tensor.int_val(i);
@@ -277,6 +280,7 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
+ CHECK_GE(output_int_data.size(), input_flat_size);
if (input_tensor.int64_val_size()) {
for (int i = 0; i < input_tensor.int64_val_size(); i++) {
output_int_data[i] = input_tensor.int64_val(i);
@@ -306,6 +310,7 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
output_array->GetMutableBuffer<ArrayDataType::kBool>().data;
output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()),
false);
+ CHECK_GE(output_bool_data.size(), input_flat_size);
if (input_tensor.bool_val_size()) {
for (int i = 0; i < input_tensor.bool_val_size(); i++) {
output_bool_data[i] = input_tensor.bool_val(i);
@@ -340,13 +345,16 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
output_array->mutable_shape());
if (!status.ok()) return status;
+ if (input_flat_size != input_tensor.string_val_size()) {
+ return Status(false,
+ "Input_content string_val doesn't have the right dimensions "
+ "for this string tensor");
+ }
+
auto& output_string_data =
output_array->GetMutableBuffer<ArrayDataType::kString>().data;
output_string_data.resize(RequiredBufferSizeForShape(output_array->shape()));
- if (input_flat_size != input_tensor.string_val_size()) {
- LOG(FATAL) << "Input_content string_val doesn't have the right "
- "dimensions for this string tensor.";
- }
+ CHECK_GE(output_string_data.size(), input_flat_size);
for (int i = 0; i < input_flat_size; ++i) {
output_string_data[i] = input_tensor.string_val(i);
}
@@ -1240,6 +1248,19 @@ void ConvertLessEqualOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
+void ConvertSinOperator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "Sin");
+ auto* op = new SinOperator;
+ const int num_inputs = GetInputsCount(node, tf_import_flags);
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
void ConvertGreaterOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -2267,6 +2288,8 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
ConvertDynamicStitchOperator(node, tf_import_flags, model);
} else if (node.op() == "RandomUniform") {
ConvertRandomUniform(node, tf_import_flags, model);
+ } else if (node.op() == "Sin") {
+ ConvertSinOperator(node, tf_import_flags, model);
} else if (node.op() == "Select") {
ConvertSelectOperator(node, tf_import_flags, model);
} else {
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 47f8db5978..d878ac54e4 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -78,6 +78,7 @@ enum class OperatorType {
kFloor,
kGather,
kResizeBilinear,
+ kSin,
kSpaceToBatchND,
kStack,
kBatchToSpaceND,
@@ -618,6 +619,17 @@ struct TanhOperator : Operator {
TanhOperator() : Operator(OperatorType::kTanh) {}
};
+// Element-wise Sin operator:
+// x -> Sin(x) = sin(x)
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Sin
+struct SinOperator : Operator {
+ SinOperator() : Operator(OperatorType::kSin) {}
+};
+
// Element-wise addition operator.
//
// Inputs:
@@ -1817,6 +1829,8 @@ class Model {
}
const ArrayMap& GetArrayMap() const { return arrays; }
+ int64 ArithmeticOpsCount() const { return ops_count; }
+
// Optional arrays are used for optional tensors,
// these tensors do not have data, but with reserved names as op inputs.
std::set<string> optional_arrays;
@@ -1833,6 +1847,8 @@ class Model {
std::size_t transient_data_size = 0;
// For code-generation only: required alignment of the transient_data buffer
std::size_t transient_data_alignment = 0;
+ // Arithmetic operations performed in the model.
+ int64 ops_count = 0;
private:
// The associative array mapping names to Array's.
diff --git a/tensorflow/contrib/lite/toco/python/toco.i b/tensorflow/contrib/lite/toco/python/toco.i
index 3787cba4a3..0d2fbdd67b 100644
--- a/tensorflow/contrib/lite/toco/python/toco.i
+++ b/tensorflow/contrib/lite/toco/python/toco.i
@@ -24,9 +24,12 @@ namespace toco {
// Convert a model represented in `input_contents`. `model_flags_proto`
// describes model parameters. `toco_flags_proto` describes conversion
// parameters (see relevant .protos for more information). Returns a string
-// representing the contents of the converted model.
+// representing the contents of the converted model. When the extended_return
+// flag is set to true, returns a dictionary that contains the string
+// representation of the converted model and statistics like arithmetic ops count.
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
PyObject* toco_flags_proto_txt_raw,
- PyObject* input_contents_txt_raw);
+ PyObject* input_contents_txt_raw,
+ bool extended_return = false);
} // namespace toco \ No newline at end of file
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
index 153c117d17..5b1db852b4 100644
--- a/tensorflow/contrib/lite/toco/python/toco_python_api.cc
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
@@ -37,7 +37,7 @@ namespace toco {
// sure we input and output bytes rather than unicode strings for Python3.
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
PyObject* toco_flags_proto_txt_raw,
- PyObject* input_contents_txt_raw) {
+ PyObject* input_contents_txt_raw, bool extended_return) {
// Use Python C API to validate and convert arguments. In py3 (bytes),
// in py2 (str).
auto ConvertArg = [&](PyObject* obj, bool* error) {
@@ -78,6 +78,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
Export(toco_flags, *model, toco_flags.allow_custom_ops(),
&output_file_contents_txt);
+ if (extended_return) {
+ PyObject* dict = PyDict_New();
+ PyDict_SetItemString(
+ dict, "flatbuffer",
+ TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(),
+ output_file_contents_txt.size()));
+ PyDict_SetItemString(dict, "arithmetic_ops",
+ PyLong_FromLong(model->ArithmeticOpsCount()));
+ return dict;
+ }
// Convert arguments back to byte (py3) or str (py2)
return TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(),
output_file_contents_txt.size());
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/contrib/lite/toco/python/toco_python_api.h
index dc378353f7..9af38e937c 100644
--- a/tensorflow/contrib/lite/toco/python/toco_python_api.h
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.h
@@ -23,10 +23,13 @@ namespace toco {
// Convert a model represented in `input_contents`. `model_flags_proto`
// describes model parameters. `toco_flags_proto` describes conversion
// parameters (see relevant .protos for more information). Returns a string
-// representing the contents of the converted model.
+// representing the contents of the converted model. When the extended_return
+// flag is set to true, returns a dictionary that contains the string
+// representation of the converted model and statistics like arithmetic ops count.
PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
PyObject* toco_flags_proto_txt_raw,
- PyObject* input_contents_txt_raw);
+ PyObject* input_contents_txt_raw,
+ bool extended_return = false);
} // namespace toco
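A hedged usage sketch of the extended_return path added in these hunks; the dictionary keys come from toco_python_api.cc above, but the exact Python module that exposes the SWIG-wrapped TocoConvert is an assumption here.

# Assumed import path for the SWIG wrapper; adjust to the actual build target.
from tensorflow.contrib.lite.toco.python import tensorflow_wrap_toco as wrap_toco

result = wrap_toco.TocoConvert(model_flags_txt, toco_flags_txt,
                               input_graph_def_txt, True)  # extended_return
tflite_flatbuffer = result["flatbuffer"]   # serialized converted model
arithmetic_ops = result["arithmetic_ops"]  # estimated arithmetic op count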
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 90e24aa104..5a999439c6 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -926,6 +926,9 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
ops.emplace_back(
new SimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect));
+ ops.emplace_back(
+ new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice));
+ ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin));
return ops;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index a4fff9974a..89da8538e4 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -117,6 +117,8 @@ TEST_F(OperatorTest, SimpleOperators) {
OperatorType::kTensorFlowLess);
CheckSimpleOperator<NegOperator>("NEG", OperatorType::kNeg);
CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect);
+ CheckSimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice);
+ CheckSimpleOperator<SinOperator>("SIN", OperatorType::kSin);
}
TEST_F(OperatorTest, BuiltinAdd) {
diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc
index c9c2e9ba01..4867c3a62e 100644
--- a/tensorflow/contrib/lite/toco/tflite/types.cc
+++ b/tensorflow/contrib/lite/toco/tflite/types.cc
@@ -36,6 +36,16 @@ DataBuffer::FlatBufferOffset CopyStringToBuffer(
return builder->CreateVector(dst_data.data(), bytes);
}
+// vector<bool> may be implemented using a bit-set, so we can't just
+// reinterpret_cast its data and let the flatbuffers CreateVector overload
+// handle it.
+// Background: https://isocpp.org/blog/2012/11/on-vectorbool
+DataBuffer::FlatBufferOffset CopyBoolToBuffer(
+ const Array& array, flatbuffers::FlatBufferBuilder* builder) {
+ const auto& src_data = array.GetBuffer<ArrayDataType::kBool>().data;
+ return builder->CreateVector(src_data);
+}
+
template <ArrayDataType T>
DataBuffer::FlatBufferOffset CopyBuffer(
const Array& array, flatbuffers::FlatBufferBuilder* builder) {
@@ -86,6 +96,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
return ::tflite::TensorType_UINT8;
case ArrayDataType::kString:
return ::tflite::TensorType_STRING;
+ case ArrayDataType::kBool:
+ return ::tflite::TensorType_BOOL;
default:
// FLOAT32 is filled for unknown data types.
// TODO(ycling): Implement type inference in TF Lite interpreter.
@@ -105,6 +117,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) {
return ArrayDataType::kString;
case ::tflite::TensorType_UINT8:
return ArrayDataType::kUint8;
+ case ::tflite::TensorType_BOOL:
+ return ArrayDataType::kBool;
default:
LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
}
@@ -125,6 +139,8 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
return CopyStringToBuffer(array, builder);
case ArrayDataType::kUint8:
return CopyBuffer<ArrayDataType::kUint8>(array, builder);
+ case ArrayDataType::kBool:
+ return CopyBoolToBuffer(array, builder);
default:
LOG(FATAL) << "Unhandled array data type.";
}
@@ -146,6 +162,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
return CopyStringFromBuffer(buffer, array);
case ::tflite::TensorType_UINT8:
return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
+ case ::tflite::TensorType_BOOL:
+ return CopyBuffer<ArrayDataType::kBool>(buffer, array);
default:
LOG(FATAL) << "Unhandled tensor type.";
}
diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc
index efb849f422..564f303b9b 100644
--- a/tensorflow/contrib/lite/toco/tflite/types_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc
@@ -28,8 +28,7 @@ using flatbuffers::Vector;
// These are types that exist in TF Mini but don't have a correspondence
// in TF Lite.
-static const ArrayDataType kUnsupportedTocoTypes[] = {ArrayDataType::kNone,
- ArrayDataType::kBool};
+static const ArrayDataType kUnsupportedTocoTypes[] = {ArrayDataType::kNone};
// These are TF Lite types for which there is no correspondence in TF Mini.
static const ::tflite::TensorType kUnsupportedTfLiteTypes[] = {
@@ -71,7 +70,8 @@ TEST(DataType, SupportedTypes) {
{ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
{ArrayDataType::kInt32, ::tflite::TensorType_INT32},
{ArrayDataType::kInt64, ::tflite::TensorType_INT64},
- {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}};
+ {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32},
+ {ArrayDataType::kBool, ::tflite::TensorType_BOOL}};
for (auto x : testdata) {
EXPECT_EQ(x.second, DataType::Serialize(x.first));
EXPECT_EQ(x.first, DataType::Deserialize(x.second));
@@ -158,6 +158,13 @@ TEST(DataBuffer, String) {
::testing::ElementsAre("AA", "BBB", "Best. String. Ever."));
}
+TEST(DataBuffer, Bool) {
+ Array recovered =
+ ToFlatBufferAndBack<ArrayDataType::kBool>({true, false, true});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kBool>().data,
+ ::testing::ElementsAre(true, false, true));
+}
+
TEST(Padding, All) {
EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame));
EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME));
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 58c99051bd..b5531ca2f4 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -86,6 +86,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveConstantRandomUniform);
transformations->Add(new ResolveConstantRange);
transformations->Add(new ResolveConstantReshape);
+ transformations->Add(new ResolveConstantSlice);
transformations->Add(new ResolveConstantStack);
transformations->Add(new ResolveConstantStridedSlice);
transformations->Add(new ResolveConstantTranspose);
@@ -372,6 +373,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
LOG(INFO) << "Estimated count of arithmetic ops: " << 1e-9 * ops_count
<< " billion (note that a multiply-add is counted as 2 ops).";
}
+ model->ops_count = ops_count;
}
void Export(const TocoFlags& toco_flags, const Model& model,
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 1f56fe5c83..1e6314f2dc 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -337,6 +337,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
HANDLE_OPERATORTYPENAME_CASE(Div)
HANDLE_OPERATORTYPENAME_CASE(Tanh)
+ HANDLE_OPERATORTYPENAME_CASE(Sin)
HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll)
HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert)
HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
@@ -986,7 +987,7 @@ void FixOperatorOrdering(Model* model) {
for (auto i : remaining) {
bool can_insert = true;
auto& op = old_operators[i];
- CHECK(op.get());
+ CHECK(op);
for (const auto& input : op->inputs) {
if (!IsConstantParameterArray(*model, input) &&
!arrays_behind_us.count(input)) {
@@ -2073,15 +2074,21 @@ bool ReshapeIsEquivalentToTranspose(const Model& model,
void CheckFinalDataTypesSatisfied(const Model& model) {
for (const auto& array_entry : model.GetArrayMap()) {
const auto& array = *array_entry.second;
+ if (array.data_type == ArrayDataType::kBool) {
+ // Boolean values are never quantized.
+ continue;
+ }
+
// If the final data type is int16, the data type may be float, for example
// after dequantization.
if (array.final_data_type != ArrayDataType::kNone &&
array.final_data_type != ArrayDataType::kInt16) {
- CHECK(array.final_data_type == array.data_type)
+ CHECK(array.data_type == array.final_data_type)
<< "Array \"" << array_entry.first
- << "\" has mis-matching actual and final data types ("
- << ArrayDataTypeName(array.data_type) << ","
- << ArrayDataTypeName(array.final_data_type) << ").";
+ << "\" has mis-matching actual and final data types (data_type="
+ << ArrayDataTypeName(array.data_type)
+ << ", final_data_type=" << ArrayDataTypeName(array.final_data_type)
+ << ").";
}
}
}
diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark_model.cc
index 93c80e0f5e..671ee8359e 100644
--- a/tensorflow/contrib/lite/tools/benchmark_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark_model.cc
@@ -354,7 +354,7 @@ int Main(int argc, char** argv) {
string output_layer_string; // e.g.: output
int num_runs = 50;
string run_delay = "-1.0";
- int num_threads = -1;
+ int num_threads = 1;
string benchmark_name = "";
string output_prefix = "";
int warmup_runs = 1;
diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc
index 17b514c916..f7df80821f 100644
--- a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc
+++ b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc
@@ -55,7 +55,7 @@ void GenerateFileContent(const std::string& tflite_path,
std::ofstream fout(filename);
fout << "#include \"" << tflite_path << "/model.h\"\n";
- fout << "#include \"" << tflite_path << "/tools/mutable_op_resolver.h\"\n";
+ fout << "#include \"" << tflite_path << "/op_resolver.h\"\n";
fout << "namespace tflite {\n";
fout << "namespace ops {\n";
diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.cc b/tensorflow/contrib/lite/tools/mutable_op_resolver.cc
index 8a921d7c5a..dc9080fd96 100644
--- a/tensorflow/contrib/lite/tools/mutable_op_resolver.cc
+++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.cc
@@ -14,30 +14,4 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
-
-namespace tflite {
-
-TfLiteRegistration* MutableOpResolver::FindOp(
- tflite::BuiltinOperator op) const {
- auto it = builtins_.find(op);
- return it != builtins_.end() ? it->second : nullptr;
-}
-
-TfLiteRegistration* MutableOpResolver::FindOp(const char* op) const {
- auto it = custom_ops_.find(op);
- return it != custom_ops_.end() ? it->second : nullptr;
-}
-
-void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
- TfLiteRegistration* registration) {
- registration->builtin_code = op;
- builtins_.insert(std::make_pair(op, registration));
-}
-
-void MutableOpResolver::AddCustom(const char* name,
- TfLiteRegistration* registration) {
- registration->builtin_code = BuiltinOperator_CUSTOM;
- custom_ops_.insert(std::make_pair(std::string(name), registration));
-}
-
-} // namespace tflite
+// TODO(ycling): Remove this file after removing other dependencies.
diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
index 573a359c45..c0f2583cdd 100644
--- a/tensorflow/contrib/lite/tools/mutable_op_resolver.h
+++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
@@ -15,41 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
#define TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
-#include <map>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/model.h"
-
-// Needed to resolve unordered_set hash on older compilers.
-namespace std {
-template <>
-struct hash<tflite::BuiltinOperator> {
- size_t operator()(const tflite::BuiltinOperator& op) const {
- return std::hash<int>()(op);
- }
-};
-} // namespace std
-
-namespace tflite {
-
-// An OpResolver that is mutable, also used as the op in gen_op_registration.
-// A typical usage:
-// MutableOpResolver resolver;
-// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
-// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
-// InterpreterBuilder(model, resolver)(&interpreter);
-class MutableOpResolver : public OpResolver {
- public:
- MutableOpResolver() {}
- TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override;
- TfLiteRegistration* FindOp(const char* op) const override;
- void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration);
- void AddCustom(const char* name, TfLiteRegistration* registration);
-
- private:
- std::map<int, TfLiteRegistration*> builtins_;
- std::map<std::string, TfLiteRegistration*> custom_ops_;
-};
-
-} // namespace tflite
+#include "tensorflow/contrib/lite/op_resolver.h"
+// MutableOpResolver is moved into `lite/op_resolver.h`.
+// TODO(ycling): Remove this file after removing other dependencies.
#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/tools/verifier.cc b/tensorflow/contrib/lite/tools/verifier.cc
index 8818a7dc85..8d3a7a6242 100644
--- a/tensorflow/contrib/lite/tools/verifier.cc
+++ b/tensorflow/contrib/lite/tools/verifier.cc
@@ -246,15 +246,16 @@ bool VerifyOps(const Model& model, const OpResolver& resolver,
}
if (opcode->builtin_code() == BuiltinOperator_CUSTOM) {
- if (!resolver.FindOp(opcode->custom_code()->c_str())) {
- ReportError(error_reporter, "Unsupported custom op: %s",
- opcode->custom_code()->c_str());
+ if (!resolver.FindOp(opcode->custom_code()->c_str(), opcode->version())) {
+ ReportError(error_reporter, "Unsupported custom op: %s, version: %d",
+ opcode->custom_code()->c_str(), opcode->version());
return false;
}
} else {
- if (!resolver.FindOp(opcode->builtin_code())) {
- ReportError(error_reporter, "Unsupported builtin op: %s",
- EnumNameBuiltinOperator(opcode->builtin_code()));
+ if (!resolver.FindOp(opcode->builtin_code(), opcode->version())) {
+ ReportError(error_reporter, "Unsupported builtin op: %s, version: %d",
+ EnumNameBuiltinOperator(opcode->builtin_code()),
+ opcode->version());
return false;
}
}
diff --git a/tensorflow/contrib/lite/tools/verifier.h b/tensorflow/contrib/lite/tools/verifier.h
index b7ce4e8305..a596c650a0 100644
--- a/tensorflow/contrib/lite/tools/verifier.h
+++ b/tensorflow/contrib/lite/tools/verifier.h
@@ -26,12 +26,13 @@ namespace tflite {
class AlwaysTrueResolver : public OpResolver {
public:
AlwaysTrueResolver() {}
- TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override {
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override {
static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr,
nullptr};
return &null_registration;
}
- TfLiteRegistration* FindOp(const char* op) const override {
+ const TfLiteRegistration* FindOp(const char* op, int version) const override {
static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr,
nullptr};
return &null_registration;
diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc
index 03b93afe3e..8a10e6848a 100644
--- a/tensorflow/contrib/lite/tools/verifier_test.cc
+++ b/tensorflow/contrib/lite/tools/verifier_test.cc
@@ -31,7 +31,6 @@ namespace tflite {
using flatbuffers::FlatBufferBuilder;
using flatbuffers::Offset;
-using flatbuffers::Vector;
// Build single subgraph model.
class TfLiteFlatbufferModelBuilder {
diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD
index e050f3c8d4..4f2c82ca23 100644
--- a/tensorflow/contrib/metrics/BUILD
+++ b/tensorflow/contrib/metrics/BUILD
@@ -77,7 +77,7 @@ py_test(
py_test(
name = "metric_ops_test",
srcs = ["python/ops/metric_ops_test.py"],
- shard_count = 8,
+ shard_count = 16,
srcs_version = "PY2AND3",
tags = ["noasan"], # times out b/63678675
deps = [
diff --git a/tensorflow/contrib/mixed_precision/BUILD b/tensorflow/contrib/mixed_precision/BUILD
new file mode 100644
index 0000000000..3dfb95e0a0
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/BUILD
@@ -0,0 +1,32 @@
+# Mixed precision training optimizers
+
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "mixed_precision",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/mixed_precision/python:loss_scale_manager",
+ "//tensorflow/contrib/mixed_precision/python:loss_scale_optimizer",
+ ],
+)
diff --git a/tensorflow/contrib/mixed_precision/__init__.py b/tensorflow/contrib/mixed_precision/__init__.py
new file mode 100644
index 0000000000..43e98cdda0
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/__init__.py
@@ -0,0 +1,34 @@
+# Copyright 2018 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library for mixed precision training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.mixed_precision.python.loss_scale_manager import *
+from tensorflow.contrib.mixed_precision.python.loss_scale_optimizer import *
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "LossScaleManager",
+ "FixedLossScaleManager",
+ "ExponentialUpdateLossScaleManager",
+ "LossScaleOptimizer",
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/mixed_precision/python/BUILD b/tensorflow/contrib/mixed_precision/python/BUILD
new file mode 100644
index 0000000000..1d769e1614
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/python/BUILD
@@ -0,0 +1,74 @@
+# Mixed precision training optimizers
+
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "loss_scale_manager",
+ srcs = ["loss_scale_manager.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:variable_scope",
+ ],
+)
+
+py_test(
+ name = "loss_scale_manager_test",
+ size = "small",
+ srcs = ["loss_scale_manager_test.py"],
+ deps = [
+ ":loss_scale_manager",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "loss_scale_optimizer",
+ srcs = ["loss_scale_optimizer.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":loss_scale_manager",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_test(
+ name = "loss_scale_optimizer_test",
+ size = "small",
+ srcs = ["loss_scale_optimizer_test.py"],
+ deps = [
+ ":loss_scale_optimizer",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py
new file mode 100644
index 0000000000..be7377b151
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py
@@ -0,0 +1,200 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""LossScaleManager classes for mixed precision training."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import six
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_control_flow_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+
+
+@six.add_metaclass(abc.ABCMeta)
+class LossScaleManager(object):
+ """Abstract loss scale manager class.
+
+ Loss scale managers with different strategies should subclass this class.
+ Loss scaling is a process that:
+
+ 1) Applies a multiplier on the loss before computing gradients, and
+ 2) Applies the reciprocal of the multiplier on the gradients before they are
+ applied on variables.
+
+ This class is used together with
+ @{tf.contrib.mixed_precision.LossScaleOptimizer} for mixed precision training
+ (float32 variables and float16 ops) on Nvidia GPUs in order to achieve the
+ same model quality as single precision training, with the benefits of
+ potential higher throughput.
+
+ See @{tf.contrib.mixed_precision.LossScaleOptimizer} for more details.
+ """
+
+ @abc.abstractmethod
+ def get_loss_scale(self):
+ """Returns the loss scale as a scalar `float32` tensor."""
+ pass
+
+ @abc.abstractmethod
+ def update_loss_scale(self, finite_grads):
+ """Updates loss scale based on if gradients are finite in current step.
+
+ Args:
+ finite_grads: bool scalar tensor indicating if all gradients are
+ finite (i.e., not inf or nan).
+
+ Returns:
+ An op that, when executed, updates the loss scale. If eager execution is
+ enabled, nothing is returned.
+ """
+ del finite_grads
+ return
+
+
+class FixedLossScaleManager(LossScaleManager):
+ """Loss scale manager with a fixed loss scale.
+
+ The loss scale is not updated for the lifetime of the class.
+ """
+
+ def __init__(self, loss_scale):
+ """Creates the fixed loss scale manager.
+
+ Args:
+ loss_scale: A Python float. Its ideal value varies depending on the model.
+ Choosing too small a loss_scale might affect model quality; too big a
+ loss_scale might cause inf or nan. There is no single right loss_scale to
+ apply; there is no harm in choosing a relatively big number as long as no
+ nan or inf is encountered in training.
+
+ Raises:
+ ValueError: If loss_scale is less than 1.
+ """
+ if loss_scale < 1:
+ raise ValueError("loss scale must be at least 1.")
+ self._loss_scale = ops.convert_to_tensor(loss_scale, dtype=dtypes.float32)
+
+ def get_loss_scale(self):
+ return self._loss_scale
+
+ def update_loss_scale(self, finite_grads):
+ del finite_grads
+ return gen_control_flow_ops.no_op()
+
+
+class ExponentialUpdateLossScaleManager(LossScaleManager):
+ """Loss scale manager uses an exponential update strategy.
+
+ In general, the strategy increases the loss scale by a greater-than-one factor
+ after encountering a consecutive series of steps with finite gradients;
+ similarly, it decreases the loss scale by a factor when the accumulated number
+ of steps with non-finite (nan or inf) gradients reaches a threshold. An update
+ is not applied if its result is less than 1 or overflows the float32 dynamic
+ range.
+
+ The counts of finite and non-finite steps are cleared every time the loss
+ scale is changed. The condition to decrease the loss scale is looser than the
+ one to increase it, since the former does not require the steps to be
+ consecutive.
+ """
+
+ def __init__(self,
+ init_loss_scale,
+ incr_every_n_steps,
+ decr_every_n_nan_or_inf=2,
+ incr_ratio=2,
+ decr_ratio=0.8):
+ """Constructor of exponential-update loss scale manager.
+
+ Args:
+ init_loss_scale: A Python float. The loss scale to use at the beginning.
+ incr_every_n_steps: Increases loss scale every n consecutive steps with
+ finite gradients.
+ decr_every_n_nan_or_inf: Decreases loss scale every n accumulated steps
+ with nan or inf gradients.
+ incr_ratio: The multiplier to use when increasing the loss scale.
+ decr_ratio: The less-than-one multiplier to use when decreasing the loss
+ scale.
+ """
+ self._incr_every_n_steps = incr_every_n_steps
+ self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
+ self._incr_ratio = incr_ratio
+ self._decr_ratio = decr_ratio
+ self._loss_scale = variable_scope.variable(
+ name="loss_scale",
+ initial_value=ops.convert_to_tensor(init_loss_scale, dtypes.float32),
+ dtype=dtypes.float32,
+ trainable=False)
+ self._num_good_steps = variable_scope.variable(
+ name="good_steps", initial_value=0, dtype=dtypes.int32, trainable=False)
+ self._num_bad_steps = variable_scope.variable(
+ name="bad_steps", initial_value=0, dtype=dtypes.int32, trainable=False)
+
+ def _reset_stats(self):
+ return control_flow_ops.group(
+ state_ops.assign(self._num_good_steps, 0),
+ state_ops.assign(self._num_bad_steps, 0))
+
+ def get_loss_scale(self):
+ """Returns the loss scale."""
+ return self._loss_scale
+
+ def update_loss_scale(self, finite_grads):
+ """Updates loss scale based on if gradients are finite in current step."""
+
+ def update_if_finite_grads():
+ """Branch function when grads are all finite."""
+
+ def incr_loss_scale():
+ new_loss_scale = control_flow_ops.cond(
+ gen_math_ops.is_finite(self._loss_scale * self._incr_ratio),
+ lambda: self._loss_scale * self._incr_ratio,
+ lambda: self._loss_scale)
+ update_op = state_ops.assign(self._loss_scale, new_loss_scale)
+ # When loss_scale is updated, both good and bad steps are reset.
+ return control_flow_ops.group(update_op, self._reset_stats())
+
+ return control_flow_ops.cond(
+ self._num_good_steps + 1 >= self._incr_every_n_steps,
+ incr_loss_scale,
+ lambda: state_ops.assign_add(self._num_good_steps, 1).op)
+
+ def update_if_not_finite_grads():
+ """Branch function when any grad is not finite."""
+
+ def decr_loss_scale():
+ update_op = state_ops.assign(
+ self._loss_scale,
+ gen_math_ops.maximum(1., self._loss_scale * self._decr_ratio))
+ # When loss_scale is updated, both good and bad steps are reset.
+ return control_flow_ops.group(update_op, self._reset_stats())
+
+ def just_update_steps():
+ # When bad_steps is incremented, good_step is reset.
+ return control_flow_ops.group(
+ state_ops.assign_add(self._num_bad_steps, 1),
+ state_ops.assign(self._num_good_steps, 0))
+
+ return control_flow_ops.cond(
+ self._num_bad_steps + 1 >= self._decr_every_n_nan_or_inf,
+ decr_loss_scale, just_update_steps)
+
+ return control_flow_ops.cond(finite_grads, update_if_finite_grads,
+ update_if_not_finite_grads)
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py
new file mode 100644
index 0000000000..480f5f6eaf
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py
@@ -0,0 +1,182 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for LossScaleManager classes.."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.mixed_precision.python import loss_scale_manager as lsm_lib
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def _GetExampleIter(inputs):
+ dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
+ return dataset.make_one_shot_iterator()
+
+
+class FixedLossScaleManagerTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_basic(self):
+ itr = _GetExampleIter([True] * 10 + [False] * 10)
+
+ loss_scale = 1000
+ lsm = lsm_lib.FixedLossScaleManager(loss_scale)
+ update_fn = lambda: lsm.update_loss_scale(itr.get_next())
+
+ self.evaluate(variables.global_variables_initializer())
+ if not context.executing_eagerly():
+ update_op = update_fn()
+ for _ in range(10):
+ if context.executing_eagerly():
+ update_fn()
+ else:
+ self.evaluate(update_op)
+ self.assertEqual(loss_scale, self.evaluate(lsm.get_loss_scale()))
+
+
+class ExponentialUpdateLossScaleManagerTest(test.TestCase):
+
+ def _test_helper(self,
+ inputs,
+ expected_outputs,
+ init_loss_scale=1,
+ incr_every_n_step=2,
+ decr_every_n_nan_or_inf=2):
+ ratio = 2
+ lsm = lsm_lib.ExponentialUpdateLossScaleManager(
+ init_loss_scale=init_loss_scale,
+ incr_every_n_steps=incr_every_n_step,
+ decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
+ incr_ratio=ratio,
+ decr_ratio=1. / ratio)
+ itr = _GetExampleIter(inputs)
+ update_fn = lambda: lsm.update_loss_scale(itr.get_next())
+
+ self.evaluate(variables.global_variables_initializer())
+ actual_outputs = []
+
+ if not context.executing_eagerly():
+ update_op = update_fn()
+ for _ in range(len(inputs)):
+ if context.executing_eagerly():
+ update_fn()
+ else:
+ self.evaluate(update_op)
+ actual_outputs.append(self.evaluate(lsm.get_loss_scale()))
+ self.assertEqual(actual_outputs, expected_outputs)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_increase_every_n_steps(self):
+ inputs = [True] * 6
+ expected_outputs = [1, 2, 2, 4, 4, 8]
+ self._test_helper(inputs, expected_outputs)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_keep_increasing_until_capped(self):
+ init_loss_scale = np.finfo(np.float32).max / 4 + 10
+ max_float = np.finfo(np.float32).max
+
+ inputs = [True] * 6
+ # Output is capped the 2nd time it doubles.
+ expected_outputs = [
+ init_loss_scale, init_loss_scale * 2, init_loss_scale * 2, max_float,
+ max_float, max_float
+ ]
+
+ self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_decrease_every_n_steps(self):
+ inputs = [False] * 6
+ init_loss_scale = 1024
+ expected_outputs = [1024, 512, 512, 256, 256, 128]
+
+ self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_keep_decreasing_until_one(self):
+ inputs = [False] * 10
+ init_loss_scale = 16
+ expected_outputs = [16, 8, 8, 4, 4, 2, 2, 1, 1, 1]
+
+ self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_incr_bad_step_clear_good_step(self):
+ inputs = [True, True, True, False, True]
+ expected_outputs = [1, 2, 2, 2, 2]
+ self._test_helper(inputs, expected_outputs)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_incr_good_step_does_not_clear_bad_step(self):
+ inputs = [True, True, True, False, True, False]
+ expected_outputs = [1, 2, 2, 2, 2, 1]
+ self._test_helper(inputs, expected_outputs)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_trigger_loss_scale_update_each_step(self):
+ """Test when incr_every_n_step and decr_every_n_nan_or_inf is 1."""
+ init_loss_scale = 1
+ incr_every_n_step = 1
+ decr_every_n_nan_or_inf = 1
+
+ inputs = [True] * 3 + [False, True, True]
+ expected_outputs = [2, 4, 8, 4, 8, 16]
+
+ self._test_helper(inputs, expected_outputs, init_loss_scale,
+ incr_every_n_step, decr_every_n_nan_or_inf)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_alternating_good_and_bad_gradients_trigger_each_step(self):
+ init_loss_scale = 1
+ incr_every_n_step = 1
+ decr_every_n_nan_or_inf = 1
+
+ inputs = [True, False] * 4 + [True]
+ expected_outputs = [2, 1, 2, 1, 2, 1, 2, 1, 2]
+ self._test_helper(inputs, expected_outputs, init_loss_scale,
+ incr_every_n_step, decr_every_n_nan_or_inf)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_alternating_good_and_bad_gradients_trigger_incr_every_2steps(self):
+ init_loss_scale = 32
+ incr_every_n_step = 2
+ decr_every_n_nan_or_inf = 1
+
+ inputs = [True, False] * 3 + [True]
+ expected_outputs = [32, 16, 16, 8, 8, 4, 4]
+ self._test_helper(inputs, expected_outputs, init_loss_scale,
+ incr_every_n_step, decr_every_n_nan_or_inf)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_random_mix_good_and_bad_gradients(self):
+ init_loss_scale = 4
+ inputs = [
+ False, False, True, True, True, False, True, False, True, True, True,
+ False
+ ]
+ expected_outputs = [4, 2, 2, 4, 4, 4, 4, 2, 2, 4, 4, 4]
+ self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
new file mode 100644
index 0000000000..e4e5ccc334
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
@@ -0,0 +1,166 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Loss scaling optimizer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_control_flow_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import optimizer
+
+
+class LossScaleOptimizer(optimizer.Optimizer):
+ """An optimizer that applies loss scaling in backprop.
+
+ This class is useful for mixed precision training on GPUs (or other potential
+ accelerators), which is an approach to improve compute throughput without loss
+ of model quality.
+
+ The common configuration of mixed precision models is the following:
+ * variables are kept in high precision (e.g. float32).
+ * computations are done in lower precision (e.g. float16); variables are
+ cast to lower precision before they're used.
+ * (in training) final gradients are cast back to variable precision before
+ being applied.
+
+ Because computations happen in lower precision, gradients in the backprop pass
+ might underflow in the smaller dynamic range, causing a model to converge at a
+ suboptimal level. This optimizer multiplies the loss by a factor before
+ backprop starts to prevent underflow. Before gradients are applied, they are
+ cast to higher precision and down-scaled by the same factor, so
+ mathematically the variable updates are no different from regular
+ same-precision training.
+
+ See [Nvidia's manual on mixed precision training](
+ https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html)
+ for more details.
+
+  To use the loss scale optimizer, one only needs to choose a loss scaling
+  strategy and wrap a regular optimizer. See the examples below.
+
+ ```
+ loss = loss_fn()
+ opt = tf.AdamOptimizer(learning_rate=...)
+
+ # Choose a loss scale manager which decides how to pick the right loss scale
+ # throughout the training process.
+  loss_scale_manager = tf.contrib.mixed_precision.FixedLossScaleManager(5000)
+
+ # Wraps the original optimizer in a LossScaleOptimizer.
+ loss_scale_optimizer = LossScaleOptimizer(opt, loss_scale_manager)
+
+ # Call minimize() on the loss scale optimizer.
+ train_op = loss_scale_optimizer.minimize(loss)
+ ```
+
+  If gradient clipping is applied, one can call
+  `loss_scale_optimizer.compute_gradients()` and
+  `loss_scale_optimizer.apply_gradients()` separately, as in the sketch below.
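+
+  For example, a minimal sketch of the separate compute/apply calls with
+  gradient clipping (the clip norm of 5.0 is only an illustrative value):
+
+  ```
+  grads_and_vars = loss_scale_optimizer.compute_gradients(loss)
+  clipped = [(tf.clip_by_norm(g, 5.0) if g is not None else None, v)
+             for g, v in grads_and_vars]
+  train_op = loss_scale_optimizer.apply_gradients(clipped)
+  ```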
+
+  Note that the following way of using LossScaleOptimizer is not intended:
+  always use `loss_scale_optimizer.compute_gradients()` to compute gradients
+  instead of `tf.gradients()` when doing mixed precision training.
+
+ ```
+ # The following is a wrong way to use LossScaleOptimizer along with
+ # tf.gradients().
+
+ # Always use loss_scale_optimizer.compute_gradients() to compute grads, or
+ # loss scale is not correctly applied.
+ grads = tf.gradients(loss, ...)
+
+ # Do some custom grad clipping.
+ grads = clip_grads(grads, ...)
+
+  loss_scale_optimizer.apply_gradients(grads_and_vars)
+ ```
+ """
+
+ def __init__(self, opt, loss_scale_manager):
+ """Construct a loss scaling optimizer.
+
+ Args:
+ opt: The actual optimizer that will be used to compute and apply the
+ gradients. Must be an implementation of the @{tf.train.Optimizer}
+ interface.
+ loss_scale_manager: A LossScaleManager object.
+ """
+ self._opt = opt
+ self._loss_scale_manager = loss_scale_manager
+
+ def compute_gradients(self,
+ loss,
+ var_list=None,
+ gate_gradients=optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ grad_loss=None):
+ """Compute gradients. See base class @{tf.train.Optimizer}."""
+ loss_scale = self._loss_scale_manager.get_loss_scale()
+ if context.executing_eagerly():
+
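+      # When executing eagerly, `loss` is expected to be a callable; wrap it
+      # so the loss scaling is applied when the wrapped optimizer evaluates
+      # the loss.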
+ def scaled_loss():
+ loss_val = loss()
+ return loss_val * math_ops.cast(loss_scale, loss_val.dtype.base_dtype)
+ else:
+ if callable(loss):
+ loss_val = loss()
+ else:
+ loss_val = loss
+ scaled_loss = loss_val * math_ops.cast(loss_scale,
+ loss_val.dtype.base_dtype)
+ grads_and_vars = self._opt.compute_gradients(
+ scaled_loss,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss)
+ return self._down_scale(grads_and_vars, loss_scale)
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients. See base class @{tf.train.Optimizer}."""
+ grads = [g for (g, _) in grads_and_vars]
+
+ is_finite_grad = []
+ for g in grads:
+ is_finite_grad.append(math_ops.reduce_all(gen_math_ops.is_finite(g)))
+ is_overall_finite = math_ops.reduce_all(is_finite_grad)
+
+ # Only update gradients when all grads are finite.
+ def true_apply_gradients_fn():
+ return self._opt.apply_gradients(grads_and_vars, global_step, name)
+
+ update_vars = control_flow_ops.cond(
+ is_overall_finite, true_apply_gradients_fn, gen_control_flow_ops.no_op)
+    # Update the loss scale depending on whether all gradients were finite.
+ return control_flow_ops.group(
+ update_vars,
+ self._loss_scale_manager.update_loss_scale(is_overall_finite))
+
+ def _down_scale(self, grads_vars, loss_scale):
+ # Down scale grads by the loss_scale.
+ gv = []
+ inv_loss_scale = gen_math_ops.reciprocal(loss_scale)
+ for g, v in grads_vars:
+ if g is not None:
+ gv.append((g * math_ops.cast(inv_loss_scale, g.dtype.base_dtype), v))
+ else:
+ gv.append((g, v))
+ return gv
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py
new file mode 100644
index 0000000000..dded61ccd5
--- /dev/null
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py
@@ -0,0 +1,216 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for LossScaleOptimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.mixed_precision.python import loss_scale_manager as lsm_lib
+from tensorflow.contrib.mixed_precision.python import loss_scale_optimizer as lso
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent as gd
+
+
+class LossScaleOptimizerTest(test.TestCase):
+
+ def _build_graph(self, lr, init_val, loss_scale_opt_fn=None):
+ x = variable_scope.get_variable(
+ "x", initializer=init_val, dtype=dtypes.float32)
+ c1 = constant_op.constant(1e4, dtype=dtypes.float16)
+ c2 = constant_op.constant(1e-4, dtype=dtypes.float16)
+ c3 = constant_op.constant(1e-4, dtype=dtypes.float16)
+ if context.executing_eagerly():
+ loss = lambda: math_ops.cast(x, dtypes.float16) * c1 * c2 * c3
+ else:
+ loss = math_ops.cast(x, dtypes.float16) * c1 * c2 * c3
+
+ opt = gd.GradientDescentOptimizer(lr)
+ if loss_scale_opt_fn:
+ opt = loss_scale_opt_fn(opt)
+ return x, loss, opt
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_float16_underflow_without_loss_scale(self):
+ lr = 1
+ init_val = 1.
+ x, loss, opt = self._build_graph(lr, init_val)
+
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(opt.minimize(loss, var_list=[x]))
+
+ # Symbolic grad is c1 * c2 * c3 = 1e-4 and actual grad is 0, since in
+ # backprop, c2 * c3 underflows in fp16 range. So variable isn't updated.
+ expected_update = 0
+ symbolic_update = 1e-4 * lr
+ self.assertAllClose(
+ init_val - expected_update,
+ self.evaluate(x),
+ rtol=0,
+ atol=min(symbolic_update, 1e-6))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_float16_with_loss_scale(self):
+ lr = 1.
+ init_val = 1.
+
+ def loss_scale_opt_fn(opt):
+ return lso.LossScaleOptimizer(opt, lsm_lib.FixedLossScaleManager(1e4))
+
+ x, loss, opt = self._build_graph(lr, init_val, loss_scale_opt_fn)
+
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(opt.minimize(loss, var_list=[x]))
+
+ # Symbolic grad is c1 * c2 * c3 = 1e-4 and actual grad is the same, due to
+ # up-scaled loss before backprop starts.
+ expected_update = 1.e-4 * lr
+ self.assertAllClose(
+ init_val - expected_update,
+ self.evaluate(x),
+ rtol=0,
+ atol=min(expected_update, 1e-6))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_compute_gradients_with_loss_scale(self):
+ lr = 1
+ init_val = 1.
+
+ def loss_scale_opt_fn(opt):
+ return lso.LossScaleOptimizer(opt, lsm_lib.FixedLossScaleManager(1e4))
+
+ x, loss, opt = self._build_graph(lr, init_val, loss_scale_opt_fn)
+ grads_and_vars = opt.compute_gradients(loss, var_list=[x])
+
+ self.assertEqual(len(grads_and_vars), 1)
+
+ self.evaluate(variables.global_variables_initializer())
+ g_v = self.evaluate(grads_and_vars[0][0])
+ self.assertAllClose(g_v, 1e-4)
+ self.assertIs(grads_and_vars[0][1], x)
+ # Gradients aren't applied.
+ self.assertAllClose(init_val, self.evaluate(x), rtol=0, atol=1e-6)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_compute_gradients_without_loss_scale(self):
+ lr = 1
+ init_val = 1.
+ x, loss, opt = self._build_graph(lr, init_val)
+ grads_and_vars = opt.compute_gradients(loss, var_list=[x])
+
+ self.assertEqual(len(grads_and_vars), 1)
+ self.evaluate(variables.global_variables_initializer())
+ g_v = self.evaluate(grads_and_vars[0][0])
+ self.assertAllClose(g_v, 0)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_apply_gradients(self):
+
+ x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1])
+ itr = dataset.make_one_shot_iterator()
+
+ lr = 1
+ opt = gd.GradientDescentOptimizer(lr)
+ lsm = lsm_lib.FixedLossScaleManager(1.e4)
+ opt = lso.LossScaleOptimizer(opt, lsm)
+ train_fn = lambda: opt.apply_gradients([(itr.get_next(), x)])
+ if not context.executing_eagerly():
+ train_op = train_fn()
+
+ expected_output = [1, 1, 1 - 0.1]
+ actual_output = []
+
+ self.evaluate(variables.global_variables_initializer())
+ for _ in range(3):
+      # Steps with NaN or inf gradients are not applied.
+ if context.executing_eagerly():
+ train_fn()
+ else:
+ self.evaluate(train_op)
+ actual_output.append(self.evaluate(x))
+ self.assertAllClose(expected_output, actual_output)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_apply_gradients_loss_scale_is_updated(self):
+
+ class SimpleLossScaleManager(lsm_lib.LossScaleManager):
+ """A simple loss scale manager for easier testing.
+
+      It increments the loss scale by 1 if the gradients are finite, and
+      decreases it by 1 otherwise.
+ """
+
+ def __init__(self, loss_scale):
+ self._loss_scale = variable_scope.variable(
+ name="loss_scale",
+ initial_value=loss_scale,
+ dtype=dtypes.float32,
+ trainable=False)
+
+ def get_loss_scale(self):
+ return self._loss_scale
+
+ def update_loss_scale(self, if_finite_grads):
+ return control_flow_ops.cond(
+ if_finite_grads, lambda: state_ops.assign_add(self._loss_scale, 1),
+ lambda: state_ops.assign_sub(self._loss_scale, 1))
+
+ x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1])
+ itr = dataset.make_one_shot_iterator()
+
+ lr = 1
+ init_loss_scale = 8
+ opt = gd.GradientDescentOptimizer(lr)
+ lsm = SimpleLossScaleManager(init_loss_scale)
+ opt = lso.LossScaleOptimizer(opt, lsm)
+ train_fn = lambda: opt.apply_gradients([(itr.get_next(), x)])
+ if not context.executing_eagerly():
+ train_op = train_fn()
+
+ self.evaluate(variables.global_variables_initializer())
+
+ expected_loss_scale = [
+ init_loss_scale - 1, init_loss_scale - 2, init_loss_scale - 2 + 1
+ ]
+ expected_output = [1, 1, 1 - 0.1]
+ actual_output = []
+ for i in range(3):
+      # Steps with NaN or inf gradients are not applied.
+ if context.executing_eagerly():
+ train_fn()
+ else:
+ self.evaluate(train_op)
+ actual_output.append(self.evaluate(x))
+ self.assertAllClose(expected_loss_scale[i],
+ self.evaluate(lsm._loss_scale))
+ self.assertAllClose(expected_output, actual_output)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/mpi/mpi_utils.h b/tensorflow/contrib/mpi/mpi_utils.h
index 4091925fc0..45dc934934 100644
--- a/tensorflow/contrib/mpi/mpi_utils.h
+++ b/tensorflow/contrib/mpi/mpi_utils.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
// Skip MPI C++ bindings support, this matches the usage in other places
#define OMPI_SKIP_MPICXX
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 9e2858d00f..b1f2e9d860 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -31,14 +31,15 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl.keras.engine import training
from tensorflow.python.keras._impl.keras.layers import core
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import checkpointable
from tensorflow.python.training import checkpointable_utils
@@ -139,8 +140,9 @@ class CheckpointingTests(test.TestCase):
self.evaluate(checkpointable_utils.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
- named_variables, serialized_graph = (
- checkpointable_utils._serialize_object_graph(root_checkpointable))
+ named_variables, serialized_graph, _ = (
+ checkpointable_utils._serialize_object_graph(
+ root_checkpointable, saveables_cache=None))
expected_checkpoint_names = (
# Created in the root node, so no prefix.
"optimizer_step",
@@ -163,24 +165,29 @@ class CheckpointingTests(test.TestCase):
suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
expected_checkpoint_names = [
name + suffix for name in expected_checkpoint_names]
+ # The Dense layers also save get_config() JSON
+ expected_checkpoint_names.extend(
+ ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+ "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+ named_variables = {v.name: v for v in named_variables}
six.assertCountEqual(self, expected_checkpoint_names,
named_variables.keys())
# Check that we've mapped to the right variable objects (not exhaustive)
self.assertEqual(
- "global_step:0",
- named_variables["optimizer_step" + suffix].name)
+ "global_step",
+ named_variables["optimizer_step" + suffix].full_name)
self.assertEqual(
- "my_model/dense_1/kernel:0",
- named_variables["model/_second/kernel" + suffix].name)
+ "my_model/dense_1/kernel",
+ named_variables["model/_second/kernel" + suffix].full_name)
self.assertEqual(
- "my_model/dense/kernel:0",
- named_variables["model/_named_dense/kernel" + suffix].name)
+ "my_model/dense/kernel",
+ named_variables["model/_named_dense/kernel" + suffix].full_name)
self.assertEqual(
- "beta1_power:0",
- named_variables["optimizer/beta1_power" + suffix].name)
+ "beta1_power",
+ named_variables["optimizer/beta1_power" + suffix].full_name)
self.assertEqual(
- "beta2_power:0",
- named_variables["optimizer/beta2_power" + suffix].name)
+ "beta2_power",
+ named_variables["optimizer/beta2_power" + suffix].full_name)
# Spot check the generated protocol buffers.
self.assertEqual("optimizer",
serialized_graph.nodes[0].children[1].local_name)
@@ -205,7 +212,7 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(
"my_model/dense/kernel/Adam:0",
optimizer.get_slot(
- var=named_variables["model/_named_dense/kernel" + suffix],
+ var=model._named_dense.kernel,
name="m").name)
self.assertEqual(
"model/_named_dense/kernel" + suffix,
@@ -417,16 +424,6 @@ class CheckpointingTests(test.TestCase):
self.evaluate(root.save_counter))
# pylint: enable=cell-var-from-loop
- def _get_checkpoint_name(self, name):
- root = checkpointable.Checkpointable()
- checkpointable_utils.add_variable(
- root, name=name, shape=[1, 2], dtype=dtypes.float64)
- named_variables, _ = checkpointable_utils._serialize_object_graph(root)
- checkpoint_name, = named_variables.keys()
- with ops.name_scope("root/" + checkpoint_name):
- pass # Make sure we can use this as an op name if we prefix it.
- return checkpoint_name
-
def testAnonymousVarsInInit(self):
class Model(training.Model):
@@ -617,6 +614,49 @@ class CheckpointingTests(test.TestCase):
self.assertAllEqual(3., self.evaluate(beta1_power))
+class TemplateTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_checkpointable_save_restore(self):
+
+ def _templated():
+ v = variable_scope.get_variable(
+ "v", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ v2 = variable_scope.get_variable(
+ "v2", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ return v, v + 1., v2
+
+ save_template = template.make_template("s1", _templated)
+ v1_save, _, v2_save = save_template()
+ optimizer = adam.AdamOptimizer(0.0)
+ save_root = checkpointable_utils.Checkpoint(
+ my_template=save_template, optimizer=optimizer)
+ optimizer.minimize(v1_save.read_value)
+ self.evaluate([v.initializer for v in optimizer.variables()])
+ self.evaluate(v1_save.assign([12.]))
+ self.evaluate(v2_save.assign([14.]))
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = save_root.save(checkpoint_prefix)
+
+ load_template = template.make_template("s2", _templated)
+ load_optimizer = adam.AdamOptimizer(0.0)
+ load_root = checkpointable_utils.Checkpoint(
+ my_template=load_template, optimizer=load_optimizer)
+ status = load_root.restore(save_path)
+ var, var_plus_one, var2 = load_template()
+ load_optimizer.minimize(var.read_value)
+ self.assertEqual(2, len(load_template._checkpoint_dependencies))
+ self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
+ self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
+ status.assert_consumed().run_restore_ops()
+ self.assertAllEqual([12.], self.evaluate(var))
+ self.assertAllEqual([13.], self.evaluate(var_plus_one))
+ self.assertAllEqual([14.], self.evaluate(var2))
+
+
class CheckpointCompatibilityTests(test.TestCase):
def _initialized_model(self):
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 46bfbb729f..694a3cebd6 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -360,7 +360,16 @@ class _OptimizerV2State(object):
"""
slot_variable = self.get_slot(var=variable, name=slot_name)
if (slot_variable is None and context.executing_eagerly() and
- slot_variable_position.is_simple_variable()):
+ slot_variable_position.is_simple_variable()
+ # Defer slot variable creation if there is an active variable creator
+ # scope. Generally we'd like to eagerly create/restore slot variables
+ # when possible, but this may mean that scopes intended to catch
+ # `variable` also catch its eagerly created slot variable
+ # unintentionally (specifically make_template would add a dependency on
+ # a slot variable if not for this case). Deferring is mostly harmless
+ # (aside from double initialization), and makes variable creator scopes
+ # behave the same way they do when graph building.
+ and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
initializer = checkpointable.CheckpointInitialValue(
checkpoint_position=slot_variable_position)
slot_variable = self.create_slot(
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 76f695dce0..55479bf5f7 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -475,7 +475,7 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
def _IsValidUnfusedBatchNorm(graph, context):
"""Checks that the output of the unfused batch norm has consumers."""
add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm/add_1')
+ context + '/BatchNorm/batchnorm_1/add_1')
# Ensure that the output tensor of batch norm has consumers, otherwise this
# is a dangling node and not a match.
return bool(add_shift.outputs[0].consumers())
@@ -568,7 +568,7 @@ def _GetBatchNormParams(graph, context, has_scaling):
op_suffix_mean = '/BatchNorm/moments/Squeeze'
op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
- op_suffix_epsilon = '/BatchNorm/batchnorm/add/y'
+ op_suffix_epsilon = '/BatchNorm/batchnorm_1/add/y'
op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
@@ -643,12 +643,12 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
Returns:
A pair of Operations, the first is the original consumer node of the batch
- norm (../BatchNorm/batchnorm/add_1), the second is the consumer node of
+ norm (../BatchNorm/batchnorm_1/add_1), the second is the consumer node of
the folded graph (add_fold).
"""
mul_scale_name = 'mul_1' if has_scaling else 'mul'
mul_scale = graph.get_operation_by_name(context +
- '/BatchNorm/batchnorm/' +
+ '/BatchNorm/batchnorm_1/' +
mul_scale_name)
op_below = mul_scale.inputs[0].op
weights = op_below.inputs[1]
@@ -670,7 +670,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
]
scale_name = 'mul' if has_scaling else 'Rsqrt'
scale = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm/' + scale_name)
+ context + '/BatchNorm/batchnorm_1/' + scale_name)
scale = array_ops.reshape(scale.outputs[0], new_shape,
context + '/scale_reshape')
@@ -698,7 +698,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
[(1, mul_fold.outputs[0])])
add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm/add_1')
+ context + '/BatchNorm/batchnorm_1/add_1')
corrected_output = conv_or_fc_folded.outputs[0]
if correction_offset is not None:
@@ -886,7 +886,7 @@ def _HasScaling(graph, input_to_ops_map, bn):
Returns:
A boolean indicating whether this batch norm layer has scaling enabled.
"""
- rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt')
+ rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm_1/Rsqrt')
rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
index fa5e11b470..bfa9d3bf70 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
@@ -516,13 +516,13 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
if has_scaling:
if fused:
return scope + '/BatchNorm_Fold/mul'
- return scope + '/BatchNorm/batchnorm/mul'
- return scope + '/BatchNorm/batchnorm/Rsqrt'
+ return scope + '/BatchNorm/batchnorm_1/mul'
+ return scope + '/BatchNorm/batchnorm_1/Rsqrt'
def _BathNormBiasName(self, scope, fused):
if fused:
return scope + '/BatchNorm_Fold/bias'
- return scope + '/BatchNorm/batchnorm/sub'
+ return scope + '/BatchNorm/batchnorm_1/sub'
def _WeightInit(self, stddev):
"""Returns a truncated normal variable initializer.
diff --git a/tensorflow/contrib/quantize/python/graph_matcher.py b/tensorflow/contrib/quantize/python/graph_matcher.py
index bacc707a3a..aa3ca991c0 100644
--- a/tensorflow/contrib/quantize/python/graph_matcher.py
+++ b/tensorflow/contrib/quantize/python/graph_matcher.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import abc
+import itertools
class Pattern(object):
@@ -33,7 +34,7 @@ class Pattern(object):
class OpTypePattern(Pattern):
"""A tree pattern that matches TF expressions with certain op types."""
- def __init__(self, op_type, name=None, inputs=None):
+ def __init__(self, op_type, name=None, inputs=None, ordered_inputs=True):
"""Initializes an OpTypePattern.
Args:
@@ -48,16 +49,25 @@ class OpTypePattern(Pattern):
inputs: Optional list of `Pattern`s or strings that specify the
patterns for the inputs of a matching op. If None, this pattern accepts
any inputs of a matching op.
+      ordered_inputs: Defaults to True. If False, the pattern will match any
+        op whose inputs match the given input patterns in any order.
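+        For example, `OpTypePattern('Add', inputs=['Const', 'Placeholder'],
+        ordered_inputs=False)` also matches an Add op whose inputs arrive as
+        (Placeholder, Const); see test_ordered_pattern in
+        graph_matcher_test.py.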
+
+ Raises:
+      ValueError: if too many inputs are provided when ordered_inputs is
+        False.
"""
self._op_type = op_type
self._name = name
if inputs is None:
inputs = []
+    if not ordered_inputs and len(inputs) > 8:
+      raise ValueError(
+          'At most 8 inputs are allowed when ordered_inputs is False.')
self._inputs = [
input_pattern
if isinstance(input_pattern, Pattern) else OpTypePattern(input_pattern)
for input_pattern in inputs
]
+ self._ordered_inputs = ordered_inputs
@property
def name(self):
@@ -78,12 +88,23 @@ class OpTypePattern(Pattern):
if len(op.inputs) != len(self._inputs):
return None
- for input_tensor, input_pattern in zip(op.inputs, self._inputs):
- input_match_result = input_pattern.match(input_tensor.op, input_tensor)
- if input_match_result is None:
- return None
- match_result.merge_from(input_match_result)
- return match_result
+ input_patterns_list = [self._inputs]
+ # If order doesn't matter for the inputs, then make sure we match at least
+ # one permutation of the inputs.
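+    # The number of permutations grows factorially with the number of inputs,
+    # which is why __init__ limits unordered patterns to 8 inputs.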
+ if not self._ordered_inputs:
+ input_patterns_list = list(itertools.permutations(self._inputs))
+
+ for input_patterns in input_patterns_list:
+ match_failed = False
+ for input_tensor, input_pattern in zip(op.inputs, input_patterns):
+ input_match_result = input_pattern.match(input_tensor.op, input_tensor)
+ if input_match_result is None:
+ match_failed = True
+ break
+ match_result.merge_from(input_match_result)
+ if not match_failed:
+ return match_result
+ return None
class OneofPattern(Pattern):
diff --git a/tensorflow/contrib/quantize/python/graph_matcher_test.py b/tensorflow/contrib/quantize/python/graph_matcher_test.py
index 6d58757218..be741644b6 100644
--- a/tensorflow/contrib/quantize/python/graph_matcher_test.py
+++ b/tensorflow/contrib/quantize/python/graph_matcher_test.py
@@ -22,6 +22,7 @@ from tensorflow.contrib.framework.python import ops as contrib_ops
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import graph_matcher
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@@ -163,6 +164,44 @@ class GraphMatcherTest(test_util.TensorFlowTestCase):
self.assertEqual(match_result.get_tensor('slice'), slicing)
self.assertEqual(match_result.get_op('transpose'), transpose.op)
+ def test_ordered_pattern(self):
+    #    +                +
+    #   / \              / \
+    #  x   y    and     y   x   should both match when ordered_inputs is False.
+ # Even when x and y are different operations.
+ g = ops.Graph()
+ with g.as_default():
+ x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
+ y = constant_op.constant(1.0, dtype=dtypes.float32)
+ plus = x + y
+
+ add_pattern_a = graph_matcher.OpTypePattern(
+ 'Add', inputs=['Const', 'Placeholder'], ordered_inputs=False)
+ add_pattern_b = graph_matcher.OpTypePattern(
+ 'Add', inputs=['Placeholder', 'Const'], ordered_inputs=False)
+ add_pattern_fail = graph_matcher.OpTypePattern(
+ 'Add', inputs=['Const', 'Placeholder'], ordered_inputs=True)
+ # Both add_pattern_a and add_pattern_b should match the graph since
+    # ordered_inputs was set to False.
+ matcher_a = graph_matcher.GraphMatcher(add_pattern_a)
+ self.assertEqual([
+ match_result.get_op(add_pattern_a)
+ for match_result in matcher_a.match_graph(g)
+ ], [plus.op])
+ matcher_b = graph_matcher.GraphMatcher(add_pattern_b)
+ self.assertEqual([
+ match_result.get_op(add_pattern_b)
+ for match_result in matcher_b.match_graph(g)
+ ], [plus.op])
+    # But if ordered_inputs is True, the match should fail when the inputs are
+    # not specified in the right order.
+ matcher_fail = graph_matcher.GraphMatcher(add_pattern_fail)
+ self.assertEqual(
+ len([
+ match_result.get_op(add_pattern_fail)
+ for match_result in matcher_fail.match_graph(g)
+ ]), 0)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py
index 5c0e17dc86..27069444a4 100644
--- a/tensorflow/contrib/quantize/python/quant_ops.py
+++ b/tensorflow/contrib/quantize/python/quant_ops.py
@@ -81,7 +81,8 @@ def LastValueQuantize(inputs,
a tensor containing quantized values.
"""
with variable_scope.variable_scope(
- None, default_name=name_prefix, values=[inputs], reuse=reuse):
+ None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope:
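+    # The min/max variables created here are scalars, so clear any partitioner
+    # inherited from an enclosing variable_scope (see quant_ops_test.py).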
+ scope.set_partitioner(None)
input_shape = inputs.get_shape()
input_dim = len(input_shape)
if per_channel:
@@ -189,7 +190,8 @@ def MovingAvgQuantize(inputs,
a tensor containing quantized values.
"""
with variable_scope.variable_scope(
- None, default_name=name_prefix, values=[inputs], reuse=reuse):
+ None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope:
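+    # The min/max variables created here are scalars, so clear any partitioner
+    # inherited from an enclosing variable_scope (see quant_ops_test.py).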
+ scope.set_partitioner(None)
input_shape = inputs.get_shape()
input_dim = len(input_shape)
if per_channel:
diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py
index 3884679602..c2a8def480 100644
--- a/tensorflow/contrib/quantize/python/quant_ops_test.py
+++ b/tensorflow/contrib/quantize/python/quant_ops_test.py
@@ -23,6 +23,8 @@ from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -73,6 +75,36 @@ class QuantOpsTest(googletest.TestCase):
self.assertGreater(max_value, 0.0)
self.assertLess(max_value, 1.0)
+  def testVariablesNotPartitioned_LastValue(self):
+    # Variables added should not use a default partitioner since they are
+    # scalar. A TensorFlow error would be thrown if the partitioner were
+    # respected by the rewrite.
+ with ops.Graph().as_default():
+ with variable_scope.variable_scope(
+ 'part', partitioner=partitioned_variables.fixed_size_partitioner(2)):
+ x = array_ops.placeholder(dtypes.float32, shape=[2])
+ _ = quant_ops.LastValueQuantize(
+ x,
+ init_min=0.0,
+ init_max=0.0,
+ is_training=True,
+ vars_collection=_MIN_MAX_VARS)
+
+  def testVariablesNotPartitioned_MovingAvg(self):
+    # Variables added should not use a default partitioner since they are
+    # scalar. A TensorFlow error would be thrown if the partitioner were
+    # respected by the rewrite.
+ with ops.Graph().as_default():
+ with variable_scope.variable_scope(
+ 'part', partitioner=partitioned_variables.fixed_size_partitioner(2)):
+ x = array_ops.placeholder(dtypes.float32, shape=[2])
+ _ = quant_ops.MovingAvgQuantize(
+ x,
+ init_min=0.0,
+ init_max=0.0,
+ is_training=True,
+ vars_collection=_MIN_MAX_VARS)
+
def _GetMinMaxValues(self, sess):
min_max_vars = ops.get_collection(_MIN_MAX_VARS)
self.assertEqual(len(min_max_vars), 2)
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 60616ea749..4e0de24e0e 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -233,37 +233,37 @@ def _FindLayersToQuantize(graph):
weight_identity_pattern, weight_resource_var_pattern,
folded_weight_pattern
])
- ])
+ ],
+ ordered_inputs=False)
folded_bias_mul_pattern = graph_matcher.OpTypePattern(
- 'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_pattern])
+ 'Mul',
+ inputs=[graph_matcher.OpTypePattern('*'), layer_pattern],
+ ordered_inputs=False)
post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
- 'Add', inputs=[folded_bias_mul_pattern,
- graph_matcher.OpTypePattern('*')])
+ 'Add',
+ inputs=[folded_bias_mul_pattern,
+ graph_matcher.OpTypePattern('*')],
+ ordered_inputs=False)
folded_bias_add_pattern = graph_matcher.OpTypePattern(
'Add',
inputs=[
post_layer_op_correction_pattern,
graph_matcher.OpTypePattern('*')
- ])
+ ],
+ ordered_inputs=False)
bias_add_pattern = graph_matcher.OpTypePattern(
- 'Add|BiasAdd', inputs=[layer_pattern, '*'])
+ 'Add|BiasAdd', inputs=[layer_pattern, '*'], ordered_inputs=False)
# The bias can come from the bias add or the folded bias add.
- bypass_pattern_a = graph_matcher.OpTypePattern(
+ bypass_pattern = graph_matcher.OpTypePattern(
'Add',
inputs=[
graph_matcher.OneofPattern(
[bias_add_pattern, folded_bias_add_pattern]), '*'
- ])
- bypass_pattern_b = graph_matcher.OpTypePattern(
- 'Add',
- inputs=[
- '*',
- graph_matcher.OneofPattern(
- [bias_add_pattern, folded_bias_add_pattern])
- ])
+ ],
+ ordered_inputs=False)
# The input to the activation can come from bias add, fold bias add, the
# bypasses.
@@ -273,15 +273,14 @@ def _FindLayersToQuantize(graph):
'|'.join(_ACTIVATION_TYPES) + '|Identity',
inputs=[
graph_matcher.OneofPattern([
- bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a,
- bypass_pattern_b
+ bias_add_pattern,
+ folded_bias_add_pattern,
+ bypass_pattern,
])
])
- post_activation_bypass_pattern_a = graph_matcher.OpTypePattern(
- 'Add', inputs=['*', activation_pattern])
- post_activation_bypass_pattern_b = graph_matcher.OpTypePattern(
- 'Add', inputs=[activation_pattern, '*'])
+ post_activation_bypass_pattern = graph_matcher.OpTypePattern(
+ 'Add', inputs=['*', activation_pattern], ordered_inputs=False)
# The order of the following matching blocks is very important. Since matches
# aren't guaranteed to be disjoint, we structure matches from largest to
@@ -297,10 +296,7 @@ def _FindLayersToQuantize(graph):
# to ensure we don't match only the first part of this layer, missing the
# post activation bypass node.
post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher(
- graph_matcher.OneofPattern([
- post_activation_bypass_pattern_a,
- post_activation_bypass_pattern_b,
- ]))
+ post_activation_bypass_pattern)
for match_result in post_activation_bypass_layer_matcher.match_graph(graph):
layer_op = match_result.get_op(layer_pattern)
weight_tensor = match_result.get_tensor(weight_identity_pattern)
@@ -312,14 +308,9 @@ def _FindLayersToQuantize(graph):
bias_add_op = match_result.get_op(bias_add_pattern)
if bias_add_op is None:
bias_add_op = match_result.get_op(folded_bias_add_pattern)
- bypass_op = match_result.get_op(bypass_pattern_a)
- if bypass_op is None:
- bypass_op = match_result.get_op(bypass_pattern_b)
+ bypass_op = match_result.get_op(bypass_pattern)
post_activation_bypass_op = match_result.get_op(
- post_activation_bypass_pattern_a)
- if post_activation_bypass_op is None:
- post_activation_bypass_op = match_result.get_op(
- post_activation_bypass_pattern_b)
+ post_activation_bypass_pattern)
if layer_op not in matched_layer_set:
matched_layer_set.add(layer_op)
layer_matches.append(
@@ -340,9 +331,7 @@ def _FindLayersToQuantize(graph):
bias_add_op = match_result.get_op(bias_add_pattern)
if bias_add_op is None:
bias_add_op = match_result.get_op(folded_bias_add_pattern)
- bypass_op = match_result.get_op(bypass_pattern_a)
- if bypass_op is None:
- bypass_op = match_result.get_op(bypass_pattern_b)
+ bypass_op = match_result.get_op(bypass_pattern)
if layer_op not in matched_layer_set:
matched_layer_set.add(layer_op)
layer_matches.append(
diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
index cf55da2723..a42bbca611 100644
--- a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
+++ b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py
@@ -385,7 +385,7 @@ class ReceptiveFieldTest(test.TestCase):
effective_stride_y, effective_padding_x, effective_padding_y) = (
receptive_field.compute_receptive_field_from_graph_def(
graph_def, input_node, output_node,
- ['Dropout/dropout/random_uniform']))
+ ['Dropout/dropout_1/random_uniform']))
self.assertEqual(receptive_field_x, 3)
self.assertEqual(receptive_field_y, 3)
self.assertEqual(effective_stride_x, 4)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index d41fc0b3ac..e512e8db53 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -483,7 +483,12 @@ class RNNCellTest(test.TestCase):
base_cell = rnn_cell_impl.GRUCell(3)
g, m_new = base_cell(x, m)
variable_scope.get_variable_scope().reuse_variables()
- g_res, m_new_res = rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
+ wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell)
+ (name, dep), = wrapper_object._checkpoint_dependencies
+ self.assertIs(dep, base_cell)
+ self.assertEqual("cell", name)
+
+ g_res, m_new_res = wrapper_object(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run([g, g_res, m_new, m_new_res], {
x: np.array([[1., 1., 1.]]),
@@ -526,7 +531,12 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 3])
- cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/cpu:14159")
+ wrapped = rnn_cell_impl.GRUCell(3)
+ cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159")
+ (name, dep), = cell._checkpoint_dependencies
+ self.assertIs(dep, wrapped)
+ self.assertEqual("cell", name)
+
outputs, _ = cell(x, m)
self.assertTrue("cpu:14159" in outputs.device.lower())
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index ba4933ddf7..be99a5d67a 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
@@ -142,6 +143,47 @@ class TestStateSaver(object):
self.saved_state[name] = state
return array_ops.identity(state)
+ @property
+ def batch_size(self):
+ return self._batch_size
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+
+class TestStateSaverWithCounters(TestStateSaver):
+ """Class wrapper around TestStateSaver.
+
+  A dummy class used for testing static_state_saving_rnn. It helps test
+  whether the save_state and state functions are called the same number of
+  times when we evaluate the output of the RNN cell and its state, or either
+  of them separately. It inherits from TestStateSaver and adds counters for
+  those function calls.
+ """
+
+ def __init__(self, batch_size, state_size):
+ super(TestStateSaverWithCounters, self).__init__(batch_size, state_size)
+ self._num_state_calls = variables_lib.Variable(0)
+ self._num_save_state_calls = variables_lib.Variable(0)
+
+ def state(self, name):
+ with ops_lib.control_dependencies(
+ [state_ops.assign_add(self._num_state_calls, 1)]):
+ return super(TestStateSaverWithCounters, self).state(name)
+
+ def save_state(self, name, state):
+ with ops_lib.control_dependencies([state_ops.assign_add(
+ self._num_save_state_calls, 1)]):
+ return super(TestStateSaverWithCounters, self).save_state(name, state)
+
+ @property
+ def num_state_calls(self):
+ return self._num_state_calls
+
+ @property
+ def num_save_state_calls(self):
+ return self._num_save_state_calls
+
class RNNTest(test.TestCase):
@@ -186,6 +228,9 @@ class RNNTest(test.TestCase):
cell = Plus1RNNCell()
full_dropout_cell = rnn_cell.DropoutWrapper(
cell, input_keep_prob=1e-12, seed=0)
+ (name, dep), = full_dropout_cell._checkpoint_dependencies
+ self.assertIs(dep, cell)
+ self.assertEqual("cell", name)
batch_size = 2
input_size = 5
max_length = 8
@@ -1792,13 +1837,40 @@ class StateSaverRNNTest(test.TestCase):
self._seed = 23489
np.random.seed(self._seed)
- def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
+ def _factory(self, scope, state_saver):
+ num_units = state_saver.state_size // 2
+ batch_size = state_saver.batch_size
+ input_size = 5
+ max_length = 8
+ initializer = init_ops.random_uniform_initializer(
+ -0.01, 0.01, seed=self._seed)
+ cell = rnn_cell.LSTMCell(
+ num_units,
+ use_peepholes=False,
+ initializer=initializer,
+ state_is_tuple=False)
+ inputs = max_length * [
+ array_ops.zeros(dtype=dtypes.float32, shape=(batch_size, input_size))
+ ]
+ out, state = rnn.static_state_saving_rnn(
+ cell,
+ inputs,
+ state_saver=state_saver,
+ state_name="save_lstm",
+ scope=scope)
+ return out, state, state_saver
+
+ def _testScope(self, prefix="prefix", use_outer_scope=True):
+ num_units = 3
+ batch_size = 2
+ state_saver = TestStateSaver(batch_size, 2 * num_units)
+
with self.test_session(use_gpu=True, graph=ops_lib.Graph()):
if use_outer_scope:
with variable_scope.variable_scope(prefix) as scope:
- factory(scope)
+ self._factory(scope=scope, state_saver=state_saver)
else:
- factory(prefix)
+ self._factory(scope=prefix, state_saver=state_saver)
variables_lib.global_variables_initializer()
# check that all the variables names starts
@@ -1813,34 +1885,46 @@ class StateSaverRNNTest(test.TestCase):
self.assertEqual(len(scope_vars), len(all_vars))
def testStateSaverRNNScope(self):
- num_units = 3
- input_size = 5
- batch_size = 2
- max_length = 8
+ self._testScope(use_outer_scope=True)
+ self._testScope(use_outer_scope=False)
+ self._testScope(prefix=None, use_outer_scope=False)
- def factory(scope):
- initializer = init_ops.random_uniform_initializer(
- -0.01, 0.01, seed=self._seed)
- state_saver = TestStateSaver(batch_size, 2 * num_units)
- cell = rnn_cell.LSTMCell(
- num_units,
- use_peepholes=False,
- initializer=initializer,
- state_is_tuple=False)
- inputs = max_length * [
- array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
- ]
- return rnn.static_state_saving_rnn(
- cell,
- inputs,
- state_saver=state_saver,
- state_name="save_lstm",
- scope=scope)
+ def testStateSaverCallsSaveState(self):
+    """Test that state and save_state are called an equal number of times.
- self._testScope(factory, use_outer_scope=True)
- self._testScope(factory, use_outer_scope=False)
- self._testScope(factory, prefix=None, use_outer_scope=False)
+    Tests whether evaluating or skipping evaluation of the out and state
+    tensors, which are the outputs of static_state_saving_rnn, influences the
+    number of calls to the save_state and state methods of the state_saver
+    object (the number of calls should be the same).
+ """
+ num_units = 3
+ batch_size = 2
+ state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units)
+ out, state, state_saver = self._factory(scope=None, state_saver=state_saver)
+
+ with self.test_session() as sess:
+ sess.run(variables_lib.global_variables_initializer())
+ sess.run(variables_lib.local_variables_initializer())
+
+ _, _, num_state_calls, num_save_state_calls = sess.run([
+ out,
+ state,
+ state_saver.num_state_calls,
+ state_saver.num_save_state_calls])
+ self.assertEqual(num_state_calls, num_save_state_calls)
+
+ _, num_state_calls, num_save_state_calls = sess.run([
+ out,
+ state_saver.num_state_calls,
+ state_saver.num_save_state_calls])
+ self.assertEqual(num_state_calls, num_save_state_calls)
+
+ _, num_state_calls, num_save_state_calls = sess.run([
+ state,
+ state_saver.num_state_calls,
+ state_saver.num_save_state_calls])
+ self.assertEqual(num_state_calls, num_save_state_calls)
class GRUTest(test.TestCase):
diff --git a/tensorflow/contrib/signal/python/ops/window_ops.py b/tensorflow/contrib/signal/python/ops/window_ops.py
index 50094010dc..59e67e8ba4 100644
--- a/tensorflow/contrib/signal/python/ops/window_ops.py
+++ b/tensorflow/contrib/signal/python/ops/window_ops.py
@@ -47,7 +47,7 @@ def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
Raises:
ValueError: If `dtype` is not a floating point type.
- [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window
+ [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
"""
return _raised_cosine_window(name, 'hann_window', window_length, periodic,
dtype, 0.5, 0.5)
@@ -72,7 +72,7 @@ def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
Raises:
ValueError: If `dtype` is not a floating point type.
- [hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window
+ [hamming]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
"""
return _raised_cosine_window(name, 'hamming_window', window_length, periodic,
dtype, 0.54, 0.46)
diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md
index 746b955642..f2bb458848 100644
--- a/tensorflow/contrib/slim/README.md
+++ b/tensorflow/contrib/slim/README.md
@@ -909,3 +909,8 @@ slim.evaluation.evaluation_loop(
## Authors
Sergio Guadarrama and Nathan Silberman
+
+## Citation
+"TensorFlow-Slim: a lightweight library for defining, training and evaluating complex models in TensorFlow"
+S. Guadarrama, N. Silberman, 2016.
+https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
index f2d31dc8db..d877831fce 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
@@ -102,7 +102,7 @@ class BoundingBox(ItemHandler):
"""An ItemHandler that concatenates a set of parsed Tensors to Bounding Boxes.
"""
- def __init__(self, keys=None, prefix=None):
+ def __init__(self, keys=None, prefix=''):
"""Initialize the bounding box handler.
Args:
diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD
index b729fff261..d7ba754f70 100644
--- a/tensorflow/contrib/sparsemax/BUILD
+++ b/tensorflow/contrib/sparsemax/BUILD
@@ -38,7 +38,7 @@ py_library(
cuda_py_tests(
name = "sparsemax_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/sparsemax_test.py"],
additional_deps = [
":sparsemax_py",
diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py
index 99ced53e11..d22b80ac88 100644
--- a/tensorflow/contrib/summary/summary.py
+++ b/tensorflow/contrib/summary/summary.py
@@ -21,6 +21,7 @@ from @{tf.summary.merge_all} to @{tf.summary.FileWriter}.
To use with eager execution enabled, write your code as follows:
+```python
global_step = tf.train.get_or_create_global_step()
summary_writer = tf.contrib.summary.create_file_writer(
train_dir, flush_millis=10000)
@@ -30,9 +31,11 @@ with summary_writer.as_default(), tf.contrib.summary.always_record_summaries():
tf.contrib.summary.scalar("loss", my_loss)
# In this case every call to tf.contrib.summary.scalar will generate a record
# ...
+```
To use it with graph execution, write your code as follows:
+```python
global_step = tf.train.get_or_create_global_step()
summary_writer = tf.contrib.summary.create_file_writer(
train_dir, flush_millis=10000)
@@ -53,7 +56,7 @@ with tf.Session(...) as sess:
while not_done_training:
sess.run([train_op, tf.contrib.summary.all_summary_ops()])
# ...
-
+```
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 675f0b1fd6..7a8a71ac7f 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -67,6 +67,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = [
":trt_logging",
+ ":trt_plugins",
] + if_tensorrt([
"@local_config_tensorrt//:nv_infer",
]) + tf_custom_op_library_additional_deps(),
@@ -86,6 +87,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":trt_logging",
+ ":trt_plugins",
":trt_resources",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib_proto_parsing",
@@ -232,6 +234,7 @@ tf_cuda_library(
],
deps = [
":segment",
+ ":trt_plugins",
":trt_logging",
":trt_resources",
"//tensorflow/core/grappler/clusters:cluster",
@@ -263,7 +266,6 @@ cc_library(
"segment/segment.h",
"segment/union_find.h",
],
- linkstatic = 1,
deps = [
"//tensorflow/core:graph",
"//tensorflow/core:lib_proto_parsing",
@@ -286,6 +288,46 @@ tf_cc_test(
],
)
+# Library for the plugin factory
+tf_cuda_library(
+ name = "trt_plugins",
+ srcs = [
+ "plugin/trt_plugin.cc",
+ "plugin/trt_plugin_factory.cc",
+ "plugin/trt_plugin_utils.cc",
+ ],
+ hdrs = [
+ "plugin/trt_plugin.h",
+ "plugin/trt_plugin_factory.h",
+ "plugin/trt_plugin_utils.h",
+ ],
+ deps = [
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core:platform_base",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
+tf_cuda_cc_test(
+ name = "trt_plugin_factory_test",
+ size = "small",
+ srcs = ["plugin/trt_plugin_factory_test.cc"],
+ tags = [
+ "manual",
+ "notap",
+ ],
+ deps = [
+ ":trt_plugins",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ] + if_tensorrt([
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
py_test(
name = "tf_trt_integration_test",
srcs = ["test/tf_trt_integration_test.py"],
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 4df54a749f..b7b26cfb1c 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include <list>
#include <map>
@@ -77,7 +78,8 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) {
// TODO(ben,jie): ...
};
// LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
- return candidate_ops.count(node->type_string());
+ return (candidate_ops.count(node->type_string()) ||
+ PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string()));
}
void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index be559d30e0..32b211dcd1 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include <algorithm>
#include <list>
@@ -240,35 +241,49 @@ class TFAttrs {
return attrs_.at(key);
}
template <typename T>
- T get(string key) const;
+ T get(const string& key) const;
template <typename T>
- T get(string key, const T& default_value) const {
+ T get(const string& key, const T& default_value) const {
return attrs_.count(key) ? this->get<T>(key) : default_value;
}
+ std::vector<string> GetAllAttrKey() {
+ std::vector<string> attr_list;
+ for (const auto& attr_item : attrs_) {
+ attr_list.emplace_back(attr_item.first);
+ }
+ return attr_list;
+ }
+
private:
typedef std::map<string, tensorflow::AttrValue const*> AttrMap;
AttrMap attrs_;
};
template <>
-string TFAttrs::get<string>(string key) const {
+string TFAttrs::get<string>(const string& key) const {
return this->at(key)->s();
}
template <>
-std::vector<int> TFAttrs::get<std::vector<int>>(string key) const {
+std::vector<int> TFAttrs::get<std::vector<int>>(const string& key) const {
auto attr = this->at(key)->list().i();
return std::vector<int>(attr.begin(), attr.end());
}
template <>
-std::vector<string> TFAttrs::get<std::vector<string>>(string key) const {
+std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
+ auto attr = this->at(key)->list().f();
+ return std::vector<float>(attr.begin(), attr.end());
+}
+
+template <>
+std::vector<string> TFAttrs::get<std::vector<string>>(const string& key) const {
auto attr = this->at(key)->list().s();
return std::vector<string>(attr.begin(), attr.end());
}
template <>
-nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const {
+nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(const string& key) const {
auto values = this->get<std::vector<int>>(key);
nvinfer1::Dims dims;
dims.nbDims = values.size();
@@ -278,24 +293,25 @@ nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const {
}
template <>
-nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(string key) const {
+nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
return trt_dtype;
}
template <>
-tensorflow::DataType TFAttrs::get<tensorflow::DataType>(string key) const {
+tensorflow::DataType TFAttrs::get<tensorflow::DataType>(
+ const string& key) const {
return this->at(key)->type();
}
template <>
-float TFAttrs::get<float>(string key) const {
+float TFAttrs::get<float>(const string& key) const {
return this->at(key)->f();
}
template <>
-bool TFAttrs::get<bool>(string key) const {
+bool TFAttrs::get<bool>(const string& key) const {
return this->at(key)->b();
}
@@ -424,6 +440,7 @@ using OpConverter =
class Converter {
std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
std::unordered_map<string, OpConverter> op_registry_;
+ OpConverter plugin_converter_;
nvinfer1::INetworkDefinition* trt_network_;
std::list<std::vector<uint8_t>> temp_bufs_;
tensorflow::tensorrt::TRTWeightStore* weight_store_;
@@ -490,13 +507,17 @@ class Converter {
std::vector<TRT_TensorOrWeights> inputs;
TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs));
string op = node_def.op();
- if (!op_registry_.count(op)) {
- return tensorflow::errors::Unimplemented(
- "No converter registered for op: " + op);
- }
- OpConverter op_converter = op_registry_.at(op);
std::vector<TRT_TensorOrWeights> outputs;
- TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
+ if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) {
+ TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs));
+ } else {
+ if (!op_registry_.count(op)) {
+ return tensorflow::errors::Unimplemented(
+ "No converter registered for op: " + op);
+ }
+ OpConverter op_converter = op_registry_.at(op);
+ TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
+ }
for (size_t i = 0; i < outputs.size(); ++i) {
TRT_TensorOrWeights output = outputs.at(i);
// TODO(jie): tf protobuf seems to be omitting the :0 suffix
@@ -1173,6 +1194,45 @@ tensorflow::Status BinaryTensorOpTensor(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertPlugin(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ // prepare input
+ std::vector<nvinfer1::ITensor*> all_inputs;
+ for (auto input : inputs) {
+ all_inputs.emplace_back(const_cast<nvinfer1::ITensor*>(input.tensor()));
+ }
+
+ // plugin is owned by PluginFactory
+ // TODO(jie): destroy plugins later (resource management)
+ PluginTensorRT* plugin =
+ PluginFactoryTensorRT::GetInstance()->CreatePlugin(node_def.op());
+
+ // passing attributes
+ // TODO(jie): support more general attribute
+ TFAttrs attrs(node_def);
+ auto attr_key_vector = attrs.GetAllAttrKey();
+ for (auto attr_key : attr_key_vector) {
+    // TODO(jie): only lists of floats are supported in this toy example.
+ auto data = attrs.get<std::vector<float>>(attr_key);
+ size_t size_data = data.size() * sizeof(float);
+ if (!plugin->SetAttribute(attr_key, static_cast<void*>(data.data()),
+ size_data)) {
+ return tensorflow::errors::InvalidArgument("plugin SetAttribute failed");
+ }
+ }
+
+ nvinfer1::IPluginLayer* layer = ctx.network()->addPlugin(
+ &all_inputs[0], static_cast<int>(inputs.size()), *plugin);
+
+ for (int i = 0; i < layer->getNbOutputs(); i++) {
+ nvinfer1::ITensor* output_tensor = layer->getOutput(i);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ }
+ return tensorflow::Status::OK();
+}
+
tensorflow::Status ConvertPlaceholder(
Converter& ctx, const tensorflow::NodeDef& node_def,
const std::vector<TRT_TensorOrWeights>& inputs,
@@ -2073,6 +2133,8 @@ void Converter::register_op_converters() {
op_registry_["Reshape"] = ConvertReshape;
op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm;
op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm;
+
+ plugin_converter_ = ConvertPlugin;
}
} // namespace
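With the plugin_converter_ hook above, any node whose op name is registered with PluginFactoryTensorRT is lowered through ConvertPlugin instead of the per-op registry. A minimal sketch of the resulting user-side flow in Python, mirroring the example package added later in this patch (the graph contents and sizes are illustrative):

    from tensorflow.contrib import tensorrt
    from tensorflow.contrib.tensorrt import custom_plugin_examples  # registers IncPluginTRT

    # graph_def is assumed to contain a custom_plugin_examples.inc_op(...) node.
    trt_graph = tensorrt.create_inference_graph(
        input_graph_def=graph_def,
        outputs=["output"],
        max_batch_size=8,
        max_workspace_size_bytes=1 << 25,
        precision_mode="FP32",
        minimum_segment_size=2)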
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD
new file mode 100644
index 0000000000..a89cf3ab8b
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD
@@ -0,0 +1,118 @@
+# Description:
+# Example of plugin support in TensorRT (http://developer.nvidia.com/tensorrt)
+# through TensorFlow integration. Targets TensorRT 3.0.4; the APIs are expected
+# to change as TensorRT is upgraded.
+# Add :init_py to the pip package BUILD dependencies to install it.
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_library",
+ "tf_custom_op_library_additional_deps",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+)
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load(
+ "@local_config_tensorrt//:build_defs.bzl",
+ "if_tensorrt",
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["inc_op"],
+)
+
+tf_gen_op_wrapper_py(
+ name = "inc_op",
+ deps = [":inc_op_op_lib"],
+)
+
+tf_custom_op_library(
+ name = "_inc_op.so",
+ srcs = [
+ "inc_op_kernel.h",
+ "inc_op_plugin.cc",
+ "inc_op_plugin.h",
+ "ops/inc_op.cc",
+ ],
+ gpu_srcs = [
+ "inc_op_kernel.h",
+ "inc_op_kernel.cu.cc",
+ ],
+ deps = [
+ "//tensorflow/contrib/tensorrt:trt_plugins",
+ "//tensorflow/core:framework_lite",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
+tf_kernel_library(
+ name = "inc_op_plugin_kernel",
+ srcs = ["inc_op_plugin.cc"],
+ hdrs = [
+ "inc_op_kernel.h",
+ "inc_op_plugin.h",
+ ],
+ gpu_srcs = [
+ "inc_op_kernel.h",
+ "inc_op_kernel.cu.cc",
+ ],
+ deps = [
+ "//tensorflow/contrib/tensorrt:trt_plugins",
+ "//tensorflow/core:stream_executor_headers_lib",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]) + tf_custom_op_library_additional_deps(),
+)
+
+tf_custom_op_py_library(
+ name = "inc_op_loader",
+ srcs = ["inc_op.py"],
+ dso = [
+ ":_inc_op.so",
+ ],
+ kernels = [
+ ":inc_op_op_lib",
+ ":inc_op_plugin_kernel",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:resources",
+ ],
+)
+
+py_library(
+ name = "init_py",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":inc_op",
+ ":inc_op_loader",
+ ],
+)
+
+cuda_py_test(
+ name = "plugin_test",
+ size = "small",
+ srcs = ["plugin_test.py"],
+ additional_deps = [
+ ":init_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/contrib/tensorrt:init_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:tf_optimizer",
+ ],
+ tags = [
+ "manual",
+ "noguitar",
+ "notap",
+ ],
+)
diff --git a/tensorflow/python/keras/estimator/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py
index 6f931f4158..363edab2e8 100644
--- a/tensorflow/python/keras/estimator/__init__.py
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,15 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ==============================================================================
-"""Keras estimator API."""
+# =============================================================================
+"""Import custom op for plugin and register it in plugin factory registry."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.keras._impl.keras.estimator import model_to_estimator
+from tensorflow.contrib.tensorrt.custom_plugin_examples import inc_op as import_inc_op_so
+from tensorflow.contrib.tensorrt.custom_plugin_examples.ops import gen_inc_op
-del absolute_import
-del division
-del print_function
+inc_op = gen_inc_op.inc_plugin_trt
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py
new file mode 100644
index 0000000000..a007c3f54e
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py
@@ -0,0 +1,32 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Loader for the custom inc_op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import platform
+
+if platform.system() != "Windows":
+ # pylint: disable=g-import-not-at-top
+ from tensorflow.contrib.util import loader
+ from tensorflow.python.platform import resource_loader
+ # pylint: enable=g-import-not-at-top
+
+ _inc_op = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_inc_op.so"))
+else:
+ raise RuntimeError("Windows not supported")
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc
new file mode 100644
index 0000000000..988b35f74f
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h"
+
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "cuda/include/cuda_runtime_api.h"
+#include "tensorflow/core/platform/stream_executor.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+__global__ void VecInc(const float* vec, float inc, float* dest, int n) {
+ int i = blockDim.x * blockIdx.x + threadIdx.x;
+ if (i < n) dest[i] = vec[i] + inc;
+}
+
+void IncrementKernel(const float* d_input, float inc, float* d_output,
+ int count, cudaStream_t stream) {
+ int threads_per_block = 256;
+ int blocks_per_grid = (count + threads_per_block - 1) / threads_per_block;
+
+ VecInc<<<blocks_per_grid, threads_per_block, 0, stream>>>(d_input, inc,
+ d_output, count);
+}
+
+// Note: this kernel definition is not needed by the plugin_test rule, but it is
+// required for the TF program itself to be correct, i.e. the test should work
+// both when the plugin is not used and when the TRT optimization pass is run.
+class IncPluginTRT : public OpKernel {
+ public:
+ explicit IncPluginTRT(OpKernelConstruction* context) : OpKernel(context) {
+ std::vector<float> inc_list;
+ OP_REQUIRES_OK(context, context->GetAttr("inc", &inc_list));
+ OP_REQUIRES(context, inc_list.size() == 1,
+ errors::InvalidArgument(
+ "The increment list should contain single element."));
+ inc_ = inc_list[0];
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input_tensor = context->input(0);
+ const TensorShape& input_shape = input_tensor.shape();
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_shape, &output_tensor));
+ const cudaStream_t* stream = CHECK_NOTNULL(
+ reinterpret_cast<const cudaStream_t*>(context->op_device_context()
+ ->stream()
+ ->implementation()
+ ->CudaStreamMemberHack()));
+ IncrementKernel(input_tensor.flat<float>().data(), inc_,
+ output_tensor->flat<float>().data(),
+ input_shape.num_elements(), *stream);
+ }
+
+ private:
+ float inc_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("IncPluginTRT").Device(DEVICE_GPU), IncPluginTRT);
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h
new file mode 100644
index 0000000000..c35955e105
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_KERNEL_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_KERNEL_H_
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "cuda/include/cuda_runtime_api.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+void IncrementKernel(const float* d_input, float inc, float* d_output,
+ int count, cudaStream_t stream);
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_KERNEL_H_
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc
new file mode 100644
index 0000000000..8d4c893af5
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc
@@ -0,0 +1,86 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h"
+
+#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+
+const char* kPluginName = "IncPluginTRT";
+
+IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); }
+
+IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) {
+ return new IncOpPlugin(buffer, length);
+}
+
+REGISTER_TRT_PLUGIN(kPluginName, CreateIncPluginDeserialize, CreateIncPlugin);
+
+IncOpPlugin::IncOpPlugin() : plugin_name_(kPluginName) {}
+
+IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length)
+ : PluginTensorRT(serialized_data, length), plugin_name_(kPluginName) {
+ // Account for the data already consumed by the parent class's serialization.
+ size_t consumed_data = PluginTensorRT::getSerializationSize();
+ assert(length - consumed_data >= sizeof(float));
+ const char* buffer = reinterpret_cast<const char*>(serialized_data);
+ SetAttribute("inc", buffer + consumed_data, sizeof(float));
+}
+
+bool IncOpPlugin::SetAttribute(const string& key, const void* ptr,
+ const size_t size) {
+ if (strcmp(key.c_str(), "inc") == 0 && size == sizeof(float)) {
+ StoreAttribute(key, ptr, size); // Store a copy so the plugin owns the data.
+ inc_ = *static_cast<const float*>(ptr);
+ return true;
+ }
+ return false;
+}
+
+bool IncOpPlugin::GetAttribute(const string& key, const void** ptr,
+ size_t* size) const {
+ const auto& iter = attr_map_.find(key);
+ if (iter != attr_map_.end()) {
+ *ptr = iter->second.data();
+ *size = iter->second.size();
+ return true;
+ }
+ return false;
+}
+
+int IncOpPlugin::enqueue(int batch_size, const void* const* inputs,
+ void** outputs, void*, cudaStream_t stream) {
+ int count = 1;
+ for (int i = 0; i < input_dim_list_[0].nbDims; i++) {
+ count *= input_dim_list_[0].d[i];
+ }
+ count *= batch_size;
+ const float* input = reinterpret_cast<const float*>(inputs[0]);
+ float* output = reinterpret_cast<float*>(outputs[0]);
+ IncrementKernel(input, inc_, output, count, stream);
+ return 0;
+}
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h
new file mode 100644
index 0000000000..189e9c939b
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_PLUGIN_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_PLUGIN_H_
+
+#include <cassert>
+#include <cstring>
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+class IncOpPlugin : public PluginTensorRT {
+ public:
+ IncOpPlugin();
+
+ IncOpPlugin(const void* serialized_data, size_t length);
+
+ const string& GetPluginName() const override { return plugin_name_; }
+
+ bool Finalize() override { return true; }
+
+ bool SetAttribute(const string& key, const void* ptr,
+ const size_t size) override;
+
+ bool GetAttribute(const string& key, const void** ptr,
+ size_t* size) const override;
+
+ int getNbOutputs() const override { return 1; }
+
+ nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
+ int num_input_dims) override {
+ assert(index == 0);
+ assert(num_input_dims == 1);
+ return inputs[0];
+ }
+
+ // Use configure() to set up the input dimensions.
+ void configure(const nvinfer1::Dims* inputs, int num_inputs,
+ const nvinfer1::Dims* outputs, int num_outputs,
+ int max_batch_size) override {
+ assert(num_inputs == 1);
+ PluginTensorRT::configure(inputs, num_inputs, outputs, num_outputs,
+ max_batch_size);
+ }
+
+ int initialize() override { return 0; }
+
+ void terminate() override {}
+
+ size_t getWorkspaceSize(int max_batch_size) const override { return 0; }
+
+ int enqueue(int batch_size, const void* const* inputs, void** outputs,
+ void* workspace, cudaStream_t stream) override;
+
+ size_t getSerializationSize() override {
+ return PluginTensorRT::getSerializationSize() + sizeof(float);
+ }
+
+ void serialize(void* buffer) override {
+ // Serialize parent data.
+ PluginTensorRT::serialize(buffer);
+ // Advance the buffer past the parent's serialized data.
+ buffer =
+ static_cast<char*>(buffer) + PluginTensorRT::getSerializationSize();
+ std::memcpy(buffer, &inc_, sizeof(float));
+ buffer = static_cast<char*>(buffer) + sizeof(float);
+ }
+
+ protected:
+ float inc_;
+ nvinfer1::Dims dim_;
+
+ private:
+ const string plugin_name_;
+};
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_PLUGIN_H_
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc
new file mode 100644
index 0000000000..d0eb0d299d
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+
+REGISTER_OP("IncPluginTRT")
+ .Attr("inc: list(float)")
+ .Input("input: float32")
+ .Output("output: float32")
+ .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ });
+
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py
new file mode 100644
index 0000000000..bc4d270bec
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to show usage of TensorRT custom op & plugin."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy
+
+from tensorflow.contrib import tensorrt
+from tensorflow.contrib.tensorrt import custom_plugin_examples
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+
+
+class TrtPluginTest(test_util.TensorFlowTestCase):
+
+ def _get_plugin_graph_def(self):
+ """Create a simple graph and return its graph_def."""
+ g = ops.Graph()
+ with g.as_default():
+ a = array_ops.placeholder(
+ dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input")
+ relu = nn.relu(a, "relu")
+ v = nn_ops.max_pool(
+ relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
+
+ # Insert the custom op into the graph.
+ v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test")
+
+ v *= 2.0
+ v = nn.relu(v)
+ v = nn.relu(v)
+ array_ops.squeeze(v, name="output")
+ return g.as_graph_def()
+
+ def _run_graph(self, gdef, dumm_inp):
+ """Run given graphdef once."""
+ gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
+ ops.reset_default_graph()
+ g = ops.Graph()
+ with g.as_default():
+ inp, out = importer.import_graph_def(
+ graph_def=gdef, return_elements=["input", "output"])
+ inp = inp.outputs[0]
+ out = out.outputs[0]
+
+ with session.Session(
+ config=config_pb2.ConfigProto(gpu_options=gpu_options),
+ graph=g) as sess:
+ val = sess.run(out, {inp: dumm_inp})
+ return val
+
+ def testIncOpPlugin(self):
+ inp_dims = (5, 24, 24, 2)
+ dummy_input = numpy.ones(inp_dims).astype(numpy.float32)
+ orig_graph = self._get_plugin_graph_def() # graph with plugin node
+
+ # Trigger the conversion. The plugin op was registered during import, so
+ # the converter is able to create the corresponding plugin layer while
+ # converting the graph.
+ trt_graph = tensorrt.create_inference_graph(
+ input_graph_def=orig_graph,
+ outputs=["output"],
+ max_batch_size=inp_dims[0],
+ max_workspace_size_bytes=1 << 25,
+ precision_mode="FP32",
+ minimum_segment_size=2)
+ o2 = self._run_graph(trt_graph, dummy_input)
+ self.assertEqual(35, o2.reshape([-1])[0])
+
+
+if __name__ == "__main__":
+ test.main()
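The expected value 35 in testIncOpPlugin follows directly from the graph: relu and max_pool leave the all-ones input at 1.0, the plugin adds 16.5, and the multiply doubles the result, with the final relus changing nothing:

    assert (1.0 + 16.5) * 2.0 == 35.0  # value asserted by testIncOpPlugin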
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 5c5b2e3c07..9ac8047944 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
@@ -59,7 +60,8 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
infer->setGpuAllocator(allocator_.get());
#endif
trt_engine_ptr_.reset(infer->deserializeCudaEngine(
- serialized_engine_.c_str(), serialized_engine_.size(), nullptr));
+ serialized_engine_.c_str(), serialized_engine_.size(),
+ PluginFactoryTensorRT::GetInstance()));
trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
// Runtime is safe to delete after engine creation
infer->destroy();
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h
index 7f3544f8cf..96ccacb791 100644
--- a/tensorflow/contrib/tensorrt/log/trt_logger.h
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.h
@@ -28,7 +28,7 @@ namespace tensorrt {
// Logger for GIE info/warning/errors
class Logger : public nvinfer1::ILogger {
public:
- Logger(string name = "DefaultLogger") : name_(name){};
+ Logger(string name = "DefaultLogger") : name_(name) {}
void log(nvinfer1::ILogger::Severity severity, const char* msg) override;
private:
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc
new file mode 100644
index 0000000000..062f86e8bb
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc
@@ -0,0 +1,106 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
+#include <cassert>
+#include <cstring>
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+
+PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) {
+ const char* buffer = static_cast<const char*>(serialized_data);
+ size_t op_name_char_count = *reinterpret_cast<const size_t*>(buffer);
+ buffer += sizeof(size_t);
+ buffer += op_name_char_count;
+
+ size_t count = *reinterpret_cast<const size_t*>(buffer);
+ buffer += sizeof(size_t);
+
+ for (int i = 0; i < count; i++) {
+ nvinfer1::Dims dim;
+ std::memcpy(&(dim.nbDims), buffer, sizeof(dim.nbDims));
+ buffer += sizeof(dim.nbDims);
+ std::memcpy(dim.d, buffer, sizeof(dim.d));
+ buffer += sizeof(dim.d);
+ std::memcpy(dim.type, buffer, sizeof(dim.type));
+ buffer += sizeof(dim.type);
+ input_dim_list_.emplace_back(dim);
+ }
+}
+
+void PluginTensorRT::configure(const nvinfer1::Dims* inputs, int num_inputs,
+ const nvinfer1::Dims* outputs, int num_outputs,
+ int max_batch_size) {
+ for (int index = 0; index < num_inputs; index++) {
+ nvinfer1::Dims dim;
+ dim.nbDims = inputs[index].nbDims;
+ for (int i = 0; i < dim.nbDims; i++) {
+ dim.d[i] = inputs[index].d[i];
+ dim.type[i] = inputs[index].type[i];
+ }
+ input_dim_list_.emplace_back(dim);
+ }
+}
+
+size_t PluginTensorRT::getSerializationSize() {
+ nvinfer1::Dims dim;
+ return sizeof(size_t) + GetPluginName().size() +
+ sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) + sizeof(dim.d) +
+ sizeof(dim.type);
+}
+
+void PluginTensorRT::serialize(void* serialized_data) {
+ size_t op_name_size = GetPluginName().size();
+ char* buffer = static_cast<char*>(serialized_data);
+ std::memcpy(buffer, &op_name_size, sizeof(size_t));
+ buffer += sizeof(size_t);
+
+ std::memcpy(buffer, GetPluginName().data(), op_name_size);
+ buffer += op_name_size;
+
+ auto list_size = input_dim_list_.size();
+ std::memcpy(buffer, &list_size, sizeof(input_dim_list_.size()));
+ buffer += sizeof(input_dim_list_.size());
+
+ for (int i = 0; i < input_dim_list_.size(); i++) {
+ auto dim = input_dim_list_[i];
+ std::memcpy(buffer, &(dim.nbDims), sizeof(dim.nbDims));
+ buffer += sizeof(dim.nbDims);
+ std::memcpy(buffer, dim.d, sizeof(dim.d));
+ buffer += sizeof(dim.d);
+ std::memcpy(buffer, dim.type, sizeof(dim.type));
+ buffer += sizeof(dim.type);
+ }
+}
+
+bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr,
+ const size_t size) {
+ if (attr_map_.count(key) != 0) return false;
+
+ attr_map_.emplace(key, std::vector<char>(size));
+ std::memcpy(attr_map_[key].data(), ptr, size);
+ return true;
+}
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
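For reference, serialize() above writes a size_t op-name length, the op-name bytes, a size_t count of stored input dims, and then the raw nbDims/d/type fields of each nvinfer1::Dims; a subclass appends its own payload after that (IncOpPlugin appends one float). A rough Python sketch of a decoder for this layout, assuming a 64-bit size_t, Dims::MAX_DIMS == 8, and a 4-byte DimensionType enum (platform/TensorRT assumptions, not guaranteed by this patch):

    import struct

    def decode_plugin_blob(blob):
      offset = 0
      (name_len,) = struct.unpack_from("<Q", blob, offset); offset += 8
      op_name = blob[offset:offset + name_len].decode(); offset += name_len
      (num_dims,) = struct.unpack_from("<Q", blob, offset); offset += 8
      dims = []
      for _ in range(num_dims):
        (nb_dims,) = struct.unpack_from("<i", blob, offset); offset += 4
        d = struct.unpack_from("<8i", blob, offset); offset += 32
        t = struct.unpack_from("<8i", blob, offset); offset += 32
        dims.append((nb_dims, d[:nb_dims], t[:nb_dims]))
      # Any trailing bytes belong to the subclass (e.g. IncOpPlugin's `inc` float).
      return op_name, dims, blob[offset:]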
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h
new file mode 100644
index 0000000000..754920b60c
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h
@@ -0,0 +1,74 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_
+
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+// A wrapper class for TensorRT plugins. User applications should inherit from
+// this class to write custom kernels, which allows custom ops to be inserted
+// into a TensorRT engine. To make a plugin usable by the converter, the user
+// should also register its PluginDeserializeFunc and PluginConstructFunc
+// through PluginFactoryTensorRT.
+class PluginTensorRT : public nvinfer1::IPlugin {
+ public:
+ PluginTensorRT() {}
+ PluginTensorRT(const void* serialized_data, size_t length);
+
+ virtual const string& GetPluginName() const = 0;
+
+ virtual bool Finalize() = 0;
+
+ virtual bool SetAttribute(const string& key, const void* ptr,
+ const size_t size) = 0;
+ virtual bool GetAttribute(const string& key, const void** ptr,
+ size_t* size) const = 0;
+
+ void configure(const nvinfer1::Dims* inputs, int num_inputs,
+ const nvinfer1::Dims* outputs, int num_outputs,
+ int max_batch_size) override;
+
+ virtual bool StoreAttribute(const string& key, const void* ptr,
+ const size_t size);
+
+ size_t getSerializationSize() override;
+
+ void serialize(void* buffer) override;
+
+ protected:
+ std::unordered_map<string, std::vector<char> > attr_map_;
+
+ std::vector<nvinfer1::Dims> input_dim_list_;
+};
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc
new file mode 100644
index 0000000000..2bc591484d
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+
+PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name,
+ const void* serial_data,
+ size_t serial_length) {
+ size_t parsed_byte = 0;
+ // extract op_name from serial_data
+ string encoded_op_name =
+ ExtractOpName(serial_data, serial_length, &parsed_byte);
+
+ if (!IsPlugin(encoded_op_name)) {
+ return nullptr;
+ }
+
+ tensorflow::mutex_lock lock(instance_m_);
+ auto plugin_ptr =
+ plugin_registry_[encoded_op_name].first(serial_data, serial_length);
+ owned_plugins_.emplace_back(plugin_ptr);
+
+ return plugin_ptr;
+}
+
+PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string& op_name) {
+ if (!IsPlugin(op_name)) return nullptr;
+
+ tensorflow::mutex_lock lock(instance_m_);
+ auto plugin_ptr = plugin_registry_[op_name].second();
+ owned_plugins_.emplace_back(plugin_ptr);
+
+ return plugin_ptr;
+}
+
+bool PluginFactoryTensorRT::RegisterPlugin(
+ const string& op_name, PluginDeserializeFunc deserialize_func,
+ PluginConstructFunc construct_func) {
+ if (IsPlugin(op_name)) return false;
+
+ tensorflow::mutex_lock lock(instance_m_);
+ auto ret = plugin_registry_.emplace(
+ op_name, std::make_pair(deserialize_func, construct_func));
+
+ return ret.second;
+}
+
+void PluginFactoryTensorRT::DestroyPlugins() {
+ tensorflow::mutex_lock lock(instance_m_);
+ for (auto& owned_plugin_ptr : owned_plugins_) {
+ owned_plugin_ptr.release();
+ }
+ owned_plugins_.clear();
+}
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h
new file mode 100644
index 0000000000..bbae9fb65c
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_
+
+#include <memory>
+#include <unordered_map>
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
+ public:
+ // TODO(aaroey): this static method has to be inlined to make the singleton a
+ // unique global symbol. Find a way to fix it.
+ static PluginFactoryTensorRT* GetInstance() {
+ static PluginFactoryTensorRT* factory_instance =
+ new PluginFactoryTensorRT();
+ return factory_instance;
+ }
+
+ // Deserialization method
+ PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data,
+ size_t serial_length) override;
+
+ // Plugin construction, PluginFactoryTensorRT owns the plugin.
+ PluginTensorRT* CreatePlugin(const string& op_name);
+
+ bool RegisterPlugin(const string& op_name,
+ PluginDeserializeFunc deserialize_func,
+ PluginConstructFunc construct_func);
+
+ bool IsPlugin(const string& op_name) {
+ return plugin_registry_.find(op_name) != plugin_registry_.end();
+ }
+
+ size_t CountOwnedPlugins() { return owned_plugins_.size(); }
+
+ void DestroyPlugins();
+
+ protected:
+ std::unordered_map<string,
+ std::pair<PluginDeserializeFunc, PluginConstructFunc>>
+ plugin_registry_;
+
+ // TODO(jie): Owned plugins should be associated with their sessions;
+ // ownership should really be handed over to resource management.
+ std::vector<std::unique_ptr<PluginTensorRT>> owned_plugins_;
+ tensorflow::mutex instance_m_;
+};
+
+class TrtPluginRegistrar {
+ public:
+ TrtPluginRegistrar(const string& name, PluginDeserializeFunc deserialize_func,
+ PluginConstructFunc construct_func) {
+ auto factory = PluginFactoryTensorRT::GetInstance();
+ QCHECK(factory->RegisterPlugin(name, deserialize_func, construct_func))
+ << "Failed to register plugin: " << name;
+ }
+};
+
+#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \
+ REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \
+ construct_func)
+#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \
+ construct_func) \
+ REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func)
+#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \
+ static ::tensorflow::tensorrt::TrtPluginRegistrar trt_plugin_registrar##ctr \
+ TF_ATTRIBUTE_UNUSED = ::tensorflow::tensorrt::TrtPluginRegistrar( \
+ name, deserialize_func, construct_func)
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc
new file mode 100644
index 0000000000..129bdcdbc2
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc
@@ -0,0 +1,125 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace test {
+
+class StubPlugin : public PluginTensorRT {
+ public:
+ static const char* kPluginName;
+
+ StubPlugin() : plugin_name_(kPluginName) {}
+
+ StubPlugin(const void* serialized_data, size_t length)
+ : PluginTensorRT(serialized_data, length) {}
+
+ const string& GetPluginName() const override { return plugin_name_; }
+
+ bool Finalize() override { return true; }
+
+ bool SetAttribute(const string& key, const void* ptr,
+ const size_t size) override {
+ return true;
+ }
+
+ bool GetAttribute(const string& key, const void** ptr,
+ size_t* size) const override {
+ return true;
+ }
+
+ int getNbOutputs() const override { return 1; }
+
+ nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
+ int nbInputDims) override {
+ return inputs[0];
+ }
+
+ int initialize() override { return 0; }
+
+ void terminate() override {}
+
+ size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
+
+ int enqueue(int batch_size, const void* const* inputs, void** outputs,
+ void* workspace, cudaStream_t stream) override {
+ return 0;
+ }
+
+ private:
+ const string plugin_name_;
+};
+
+const char* StubPlugin::kPluginName = "StubPlugin";
+
+StubPlugin* CreateStubPlugin() { return new StubPlugin(); }
+
+StubPlugin* CreateStubPluginDeserialize(const void* serialized_data,
+ size_t length) {
+ return new StubPlugin(serialized_data, length);
+}
+
+class TrtPluginFactoryTest : public ::testing::Test {
+ public:
+ bool RegisterStubPlugin() {
+ if (PluginFactoryTensorRT::GetInstance()->IsPlugin(
+ StubPlugin::kPluginName)) {
+ return true;
+ }
+ return PluginFactoryTensorRT::GetInstance()->RegisterPlugin(
+ StubPlugin::kPluginName, CreateStubPluginDeserialize, CreateStubPlugin);
+ }
+};
+
+TEST_F(TrtPluginFactoryTest, Registration) {
+ EXPECT_FALSE(
+ PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName));
+ EXPECT_TRUE(RegisterStubPlugin());
+
+ ASSERT_TRUE(
+ PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName));
+}
+
+TEST_F(TrtPluginFactoryTest, CreationDeletion) {
+ EXPECT_TRUE(RegisterStubPlugin());
+ ASSERT_TRUE(
+ PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName));
+
+ PluginFactoryTensorRT::GetInstance()->DestroyPlugins();
+ ASSERT_TRUE(PluginFactoryTensorRT::GetInstance()->CreatePlugin(
+ StubPlugin::kPluginName));
+ ASSERT_EQ(1, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins());
+ PluginFactoryTensorRT::GetInstance()->DestroyPlugins();
+ ASSERT_EQ(0, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins());
+}
+
+} // namespace test
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc
new file mode 100644
index 0000000000..a8f60886c0
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h"
+#include <cassert>
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+
+string ExtractOpName(const void* serial_data, size_t serial_length,
+ size_t* incremental) {
+ size_t op_name_char_count = *static_cast<const size_t*>(serial_data);
+ *incremental = sizeof(size_t) + op_name_char_count;
+
+ assert(serial_length >= *incremental);
+
+ const char* buffer = static_cast<const char*>(serial_data) + sizeof(size_t);
+ string op_name(buffer, op_name_char_count);
+
+ return op_name;
+}
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h
new file mode 100644
index 0000000000..274ce42fec
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_
+
+#include <functional>
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+typedef std::function<PluginTensorRT*(const void*, size_t)>
+ PluginDeserializeFunc;
+
+typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;
+
+// TODO(jie): work on error handling here
+string ExtractOpName(const void* serial_data, size_t serial_length,
+ size_t* incremental);
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
index 8b475177bc..f36495f6b6 100644
--- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include <string>
#include <vector>
@@ -33,7 +34,8 @@ tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) {
TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine));
nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
- serialized_engine.c_str(), serialized_engine.size(), nullptr);
+ serialized_engine.c_str(), serialized_engine.size(),
+ tensorrt::PluginFactoryTensorRT::GetInstance());
int num_batch = -1;
std::vector<::tensorflow::DataType> input_type;
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index d2746032a0..e4963596d3 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -110,6 +110,7 @@ py_test(
"no_pip_gpu", # b/63391119
"nomsan", # Takes too long to run.
"notsan", # b/67865658
+ "optonly", # Takes too long to run without optimization.
],
deps = [
":ar_model",
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 706742ca28..983455f63d 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -68,15 +68,16 @@ class TimeSeriesRegressorTest(test.TestCase):
eval_input_fn = input_pipeline.RandomWindowInputFn(
input_pipeline.NumpyReader(features), shuffle_seed=3, num_threads=1,
batch_size=16, window_size=16)
- first_estimator.train(input_fn=train_input_fn, steps=5)
+ first_estimator.train(input_fn=train_input_fn, steps=1)
first_loss_before_fit = first_estimator.evaluate(
input_fn=eval_input_fn, steps=1)["loss"]
- first_estimator.train(input_fn=train_input_fn, steps=50)
+ self.assertAllEqual([], first_loss_before_fit.shape)
+ first_estimator.train(input_fn=train_input_fn, steps=1)
first_loss_after_fit = first_estimator.evaluate(
input_fn=eval_input_fn, steps=1)["loss"]
- self.assertLess(first_loss_after_fit, first_loss_before_fit)
+ self.assertAllEqual([], first_loss_after_fit.shape)
second_estimator = estimator_fn(model_dir, exogenous_feature_columns)
- second_estimator.train(input_fn=train_input_fn, steps=2)
+ second_estimator.train(input_fn=train_input_fn, steps=1)
whole_dataset_input_fn = input_pipeline.WholeDatasetInputFn(
input_pipeline.NumpyReader(features))
whole_dataset_evaluation = second_estimator.evaluate(
diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto
index 840a43913b..1f249de314 100644
--- a/tensorflow/contrib/tpu/profiler/op_profile.proto
+++ b/tensorflow/contrib/tpu/profiler/op_profile.proto
@@ -60,6 +60,11 @@ message Metrics {
// - it does not reveal the peak core FLOPS of the hardware
double flops = 2;
+ // The VMEM bandwidth used to load operands from HBM, as a fraction of
+ // theoretical VMEM bandwidth on the specific hardware.
+ double memory_bandwidth = 3;
+
double raw_time = 11; // Elapsed core-time in picoseconds.
double raw_flops = 12; // Total floating-point operations performed.
+ double raw_bytes_accessed = 13; // Total bytes accessed (includes reads and writes).
}
diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py
index faf677a81d..3e91e2df32 100644
--- a/tensorflow/contrib/tpu/python/tpu/session_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/session_support.py
@@ -292,14 +292,21 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
if self._saver:
return self._saver
- savers = ops.get_collection(ops.GraphKeys.SAVERS)[0]
+ savers = ops.get_collection(ops.GraphKeys.SAVERS)
if not savers:
return None
if not isinstance(savers, list):
return savers
- assert len(savers) == 1, 'Only one saver supported.'
+ if len(savers) > 1:
+ logging.error(
+ 'Multiple savers in the SAVERS collection. On-demand checkpointing '
+ 'will be disabled. Pass an explicit `saver` to the constructor to '
+ 'override this behavior.'
+ )
+ return None
+
return savers[0]
def after_run(self, run_context, run_values):
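Since multiple entries in the SAVERS collection now disable on-demand checkpointing, a program that builds several savers can keep the behavior by passing one explicitly. A hedged sketch; only the `saver` keyword comes from the change above, and the other names are illustrative:

    # Assumes GracefulShutdownHook keeps its existing checkpoint_prefix argument;
    # my_saver is a tf.train.Saver constructed by the caller.
    shutdown_hook = session_support.GracefulShutdownHook(
        checkpoint_prefix=checkpoint_prefix,
        saver=my_saver)  # an explicit saver bypasses the SAVERS-collection lookup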
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 4d7bc6a5a6..aefdfb9ac7 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -44,6 +44,10 @@ class _TPUContext(object):
information commonly required by TPU computation, such as TPU device names,
TPU hosts, shard batch size, etc.
+ If eval_on_tpu is False, evaluation on TPU is disabled.
+ If eval_on_tpu is True but use_tpu is False, a warning is issued and TPU
+ execution is disabled for all modes.
+
N.B. As `mode` is not immutable state in Estimator, but essential to
distinguish between TPU training and evaluation, a common usage for
_TPUContext with `mode` is as follows:
@@ -55,12 +59,17 @@ class _TPUContext(object):
"""
def __init__(self, config, train_batch_size, eval_batch_size,
- predict_batch_size, use_tpu):
+ predict_batch_size, use_tpu, eval_on_tpu=True):
self._config = config
self._train_batch_size = train_batch_size
self._eval_batch_size = eval_batch_size
self._predict_batch_size = predict_batch_size
self._use_tpu = use_tpu
+ logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu)
+ if not use_tpu and eval_on_tpu:
+ logging.warning('eval_on_tpu ignored because use_tpu is False.')
+
+ self._eval_on_tpu = eval_on_tpu
self._model_parallelism_enabled = (
use_tpu and config.tpu_config.computation_shape)
self._mode = None
@@ -246,6 +255,10 @@ class _TPUContext(object):
if not self._use_tpu:
return True
+ if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu:
+ logging.info('_is_running_on_cpu: eval_on_tpu disabled')
+ return True
+
if mode != model_fn_lib.ModeKeys.PREDICT:
return False
@@ -345,6 +358,7 @@ class _TPUContext(object):
@property
def tpu_host_placement_function(self):
"""Returns the TPU host place function."""
+
master = self.master_job
def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name
@@ -503,7 +517,7 @@ class _OneCoreTPUContext(_TPUContext):
def _get_tpu_context(config, train_batch_size, eval_batch_size,
- predict_batch_size, use_tpu):
+ predict_batch_size, use_tpu, eval_on_tpu):
"""Returns an instance of `_TPUContext`."""
if (config.tpu_config.num_shards == 1 and
@@ -515,4 +529,4 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu)
return _TPUContext(config, train_batch_size, eval_batch_size,
- predict_batch_size, use_tpu)
+ predict_batch_size, use_tpu, eval_on_tpu)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index a624eceed9..ed5db7369f 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -46,7 +46,6 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -68,6 +67,7 @@ from tensorflow.python.training import evaluation
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.training import training_util
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -76,12 +76,13 @@ _ZERO_LOSS = 0.
_TPU_ESTIMATOR = 'tpu_estimator'
_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
_BATCH_SIZE_KEY = 'batch_size'
+_CTX_KEY = 'context'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_ONE_GIGABYTE = 1024 * 1024 * 1024
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
-_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
+_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]
# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is
@@ -1269,7 +1270,7 @@ class _ModelFnWrapper(object):
def _call_model_fn(self, features, labels, is_export_mode=False):
"""Calls the model_fn with required parameters."""
- model_fn_args = util.fn_args(self._model_fn)
+ model_fn_args = function_utils.fn_args(self._model_fn)
kwargs = {}
# Makes deep copy with `config` and params` in case user mutates them.
@@ -1361,7 +1362,7 @@ class _OutfeedHostCall(object):
if isinstance(host_call[1], (tuple, list)):
fullargspec = tf_inspect.getfullargspec(host_call[0])
- fn_args = util.fn_args(host_call[0])
+ fn_args = function_utils.fn_args(host_call[0])
# wrapped_hostcall_with_global_step uses varargs, so we allow that.
if fullargspec.varargs is None and len(host_call[1]) != len(fn_args):
raise RuntimeError(
@@ -1612,7 +1613,9 @@ class TPUEstimator(estimator_lib.Estimator):
==========
`model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics`
- for TPU evaluation.
+ for TPU evaluation. However, if eval_on_tpu is False, `model_fn` must return
+ `EstimatorSpec` and the evaluation will execute on CPU or GPU; in this case
+ the following discussion on TPU evaluation does not apply.
`TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
`tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. (See
@@ -1759,7 +1762,9 @@ class TPUEstimator(estimator_lib.Estimator):
train_batch_size=None,
eval_batch_size=None,
predict_batch_size=None,
- batch_axis=None):
+ batch_axis=None,
+ eval_on_tpu=True,
+ warm_start_from=None):
"""Constructs an `TPUEstimator` instance.
Args:
@@ -1777,7 +1782,8 @@ class TPUEstimator(estimator_lib.Estimator):
basic python types. There are reserved keys for `TPUEstimator`,
including 'batch_size'.
use_tpu: A bool indicating whether TPU support is enabled. Currently,
- - TPU training and evaluation respect this bit.
+ - TPU training and evaluation respect this bit, but eval_on_tpu can
+ override where evaluation runs. See below.
- Predict still happens on CPU.
train_batch_size: An int representing the global training batch size.
TPUEstimator transforms this global batch size to a per-shard batch
@@ -1798,6 +1804,14 @@ class TPUEstimator(estimator_lib.Estimator):
and per_host_input_for_training is True, batches will be sharded based
on the major dimension. If tpu_config.per_host_input_for_training is
False or `PER_HOST_V2`, batch_axis is ignored.
+ eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the
+ model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`.
+ warm_start_from: Optional string filepath to a checkpoint or SavedModel to
+ warm-start from, or a `tf.estimator.WarmStartSettings`
+ object to fully configure warm-starting. If the string
+ filepath is provided instead of a `WarmStartSettings`,
+ then all variables are warm-started, and it is assumed
+ that vocabularies and Tensor names are unchanged.
Raises:
ValueError: `params` has reserved keys already.
@@ -1850,7 +1864,8 @@ class TPUEstimator(estimator_lib.Estimator):
model_fn=model_function,
model_dir=model_dir,
config=config,
- params=params)
+ params=params,
+ warm_start_from=warm_start_from)
self._iterations_per_training_loop = (
self._config.tpu_config.iterations_per_loop)
@@ -1859,7 +1874,8 @@ class TPUEstimator(estimator_lib.Estimator):
self._ctx = tpu_context._get_tpu_context(
self._config, train_batch_size,
eval_batch_size, predict_batch_size,
- use_tpu)
+ use_tpu,
+ eval_on_tpu)
self._is_input_fn_invoked = None
@@ -1930,7 +1946,7 @@ class TPUEstimator(estimator_lib.Estimator):
Raises:
ValueError: if input_fn takes invalid arguments or does not have `params`.
"""
- input_fn_args = util.fn_args(input_fn)
+ input_fn_args = function_utils.fn_args(input_fn)
config = self.config # a deep copy.
kwargs = {}
if 'params' in input_fn_args:
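For illustration, a minimal construction sketch using the new `eval_on_tpu` and `warm_start_from` arguments; `my_model_fn`, `tpu_grpc_url`, and the checkpoint path are placeholders, not part of this change.

import tensorflow as tf

run_config = tf.contrib.tpu.RunConfig(
    master=tpu_grpc_url,  # assumed to be resolved elsewhere
    model_dir='/tmp/model',
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=my_model_fn,
    config=run_config,
    use_tpu=True,
    train_batch_size=1024,
    eval_batch_size=1024,
    # Evaluation falls back to CPU/GPU; my_model_fn must return an
    # EstimatorSpec (not TPUEstimatorSpec) when mode == EVAL.
    eval_on_tpu=False,
    # A bare checkpoint path warm-starts all variables, assuming unchanged
    # vocabularies and Tensor names; pass WarmStartSettings for finer control.
    warm_start_from='/tmp/prev_model/model.ckpt-10000')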
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py
index f0418f04ba..3beb7bfe30 100644
--- a/tensorflow/contrib/training/python/training/hparam.py
+++ b/tensorflow/contrib/training/python/training/hparam.py
@@ -34,7 +34,7 @@ from tensorflow.python.util import deprecation
# where <rhs> is either a single token or [] enclosed list of tokens.
# For example: "var[1] = a" or "x = [1,2,3]"
PARAM_RE = re.compile(r"""
- (?P<name>[a-zA-Z][\w]*) # variable name: "var" or "x"
+ (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
(\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
\s*=\s*
((?P<val>[^,\[]*) # single value: "a" or None
@@ -200,6 +200,13 @@ def parse_values(values, type_map):
If a hyperparameter name appears in both an index assignment and a scalar assignment,
a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
+ The hyperparameter name may contain '.' symbols, which will result in an
+ attribute name that is only accessible through the getattr and setattr
+ functions. (And must first be explicitly added through add_hparam.)
+
+ WARNING: Use of '.' in your variable names is allowed, but is not well
+ supported and not recommended.
+
The `value` in `name=value` must follow the syntax according to the
type of the parameter:
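For illustration, a short sketch of the dotted-name behavior described above (the name `optimizer.momentum` is illustrative):

import tensorflow as tf

hparams = tf.contrib.training.HParams()
hparams.add_hparam('optimizer.momentum', 0.9)   # must be added explicitly first
hparams.parse('optimizer.momentum=0.95')        # the parser now accepts '.' in names
print(getattr(hparams, 'optimizer.momentum'))   # 0.95; only reachable via getattr/setattr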
diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py
index 11fd15b527..660c97f25e 100644
--- a/tensorflow/contrib/training/python/training/hparam_test.py
+++ b/tensorflow/contrib/training/python/training/hparam_test.py
@@ -118,6 +118,21 @@ class HParamsTest(test.TestCase):
self.assertEqual('2.3"', hparams2.c_c)
self.assertEqual('/a=b/c/d', hparams2.d)
+ def testWithPeriodInVariableName(self):
+ hparams = hparam.HParams()
+ hparams.add_hparam(name='a.b', value=0.0)
+ hparams.parse('a.b=1.0')
+ self.assertEqual(1.0, getattr(hparams, 'a.b'))
+ hparams.add_hparam(name='c.d', value=0.0)
+ with self.assertRaisesRegexp(ValueError, 'Could not parse'):
+ hparams.parse('c.d=abc')
+ hparams.add_hparam(name='e.f', value='')
+ hparams.parse('e.f=abc')
+ self.assertEqual('abc', getattr(hparams, 'e.f'))
+ hparams.add_hparam(name='d..', value=0.0)
+ hparams.parse('d..=10.0')
+ self.assertEqual(10.0, getattr(hparams, 'd..'))
+
def testSetFromMap(self):
hparams = hparam.HParams(a=1, b=2.0, c='tanh')
hparams.override_from_dict({'a': -2, 'c': 'identity'})
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 4b86d6ef47..33eea0a421 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -225,6 +225,7 @@ ADDITIONAL_CORE_PROTO_SRCS = [
"protobuf/named_tensor.proto",
"protobuf/saved_model.proto",
"protobuf/tensorflow_server.proto",
+ "protobuf/transport_options.proto",
"util/test_log.proto",
]
@@ -303,6 +304,7 @@ PLATFORM_OTHER_HDRS = [
"platform/cpu_info.h",
"platform/cpu_feature_guard.h",
"platform/dynamic_annotations.h",
+ "platform/error.h",
"platform/env.h",
"platform/file_system.h",
"platform/file_system_helper.h",
@@ -1519,6 +1521,13 @@ tf_pyclif_proto_library(
)
tf_pyclif_proto_library(
+ name = "framework/cost_graph_pyclif",
+ proto_lib = ":protos_all_cc",
+ proto_srcfile = "framework/cost_graph.proto",
+ visibility = ["//visibility:public"],
+)
+
+tf_pyclif_proto_library(
name = "framework/tensor_pyclif",
proto_lib = ":protos_all_cc",
proto_srcfile = "framework/tensor.proto",
@@ -2359,6 +2368,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/executor.h",
"common_runtime/graph_optimizer.h",
"common_runtime/local_device.h",
+ "common_runtime/lower_if_op.h",
"common_runtime/memory_types.h",
"common_runtime/mkl_cpu_allocator.h",
"common_runtime/optimization_registry.h",
@@ -2409,6 +2419,7 @@ tf_cuda_library(
"common_runtime/graph_optimizer.cc",
"common_runtime/graph_runner.cc",
"common_runtime/local_device.cc",
+ "common_runtime/lower_if_op.cc",
"common_runtime/memory_types.cc",
"common_runtime/mkl_cpu_allocator.cc",
"common_runtime/optimization_registry.cc",
@@ -2566,6 +2577,7 @@ tf_cuda_library(
],
copts = tf_copts(),
cuda_deps = tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps(),
+ visibility = ["//visibility:private"],
deps = [
":core_cpu_internal",
":lib",
@@ -3402,7 +3414,11 @@ tf_cuda_only_cc_test(
":test",
":test_main",
"//third_party/eigen3",
- ],
+ ] + if_mkl(
+ [
+ "//third_party/mkl:intel_binary_blob",
+ ],
+ ),
)
tf_cc_test_gpu(
@@ -4068,6 +4084,29 @@ tf_cc_test_gpu(
],
)
+tf_cc_tests(
+ name = "common_runtime_lower_if_op_test",
+ size = "small",
+ srcs = ["common_runtime/lower_if_op_test.cc"],
+ deps = [
+ ":all_kernels",
+ ":core_cpu",
+ ":core_cpu_internal",
+ ":direct_session",
+ ":framework",
+ ":framework_internal",
+ ":lib",
+ ":test",
+ ":test_main",
+ ":testlib",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:ops",
+ ],
+)
+
# Test data
filegroup(
name = "image_testdata",
diff --git a/tensorflow/core/api_def/base_api/api_def_CropAndResize.pbtxt b/tensorflow/core/api_def/base_api/api_def_CropAndResize.pbtxt
index 629f575d0a..e6609a16e1 100644
--- a/tensorflow/core/api_def/base_api/api_def_CropAndResize.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_CropAndResize.pbtxt
@@ -47,8 +47,9 @@ END
attr {
name: "method"
description: <<END
-A string specifying the interpolation method. Only 'bilinear' is
-supported for now.
+A string specifying the sampling method for resizing. It can be either
+`"bilinear"` or `"nearest"` and default to `"bilinear"`. Currently two sampling
+methods are supported: Bilinear and Nearest Neighbor.
END
}
attr {
@@ -57,18 +58,22 @@ END
Value used for extrapolation, when applicable.
END
}
- summary: "Extracts crops from the input image tensor and bilinearly resizes them (possibly"
+ summary: "Extracts crops from the input image tensor and resizes them."
description: <<END
-with aspect ratio change) to a common output size specified by `crop_size`. This
-is more general than the `crop_to_bounding_box` op which extracts a fixed size
-slice from the input image and does not allow resizing or aspect ratio change.
+Extracts crops from the input image tensor and resizes them using bilinear
+sampling or nearest neighbor sampling (possibly with aspect ratio change) to a
+common output size specified by `crop_size`. This is more general than the
+`crop_to_bounding_box` op which extracts a fixed size slice from the input image
+and does not allow resizing or aspect ratio change.
Returns a tensor with `crops` from the input `image` at positions defined at the
bounding box locations in `boxes`. The cropped boxes are all resized (with
-bilinear interpolation) to a fixed `size = [crop_height, crop_width]`. The
-result is a 4-D tensor `[num_boxes, crop_height, crop_width, depth]`. The
-resizing is corner aligned. In particular, if `boxes = [[0, 0, 1, 1]]`, the
-method will give identical results to using `tf.image.resize_bilinear()`
-with `align_corners=True`.
+bilinear or nearest neighbor interpolation) to a fixed
+`size = [crop_height, crop_width]`. The result is a 4-D tensor
+`[num_boxes, crop_height, crop_width, depth]`. The resizing is corner aligned.
+In particular, if `boxes = [[0, 0, 1, 1]]`, the method will give identical
+results to using `tf.image.resize_bilinear()` or
+`tf.image.resize_nearest_neighbor()` (depending on the `method` argument) with
+`align_corners=True`.
END
}
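For illustration, a minimal sketch of selecting the new sampling method; the shapes and box values are arbitrary placeholders.

import tensorflow as tf

image = tf.random_uniform([2, 64, 64, 3])          # batch of 2 images
boxes = tf.constant([[0.0, 0.0, 1.0, 1.0],
                     [0.25, 0.25, 0.75, 0.75]])    # normalized [y1, x1, y2, x2]
box_ind = tf.constant([0, 1])                      # source image for each box

crops = tf.image.crop_and_resize(image, boxes, box_ind, crop_size=[32, 32],
                                 method='nearest')  # default remains 'bilinear'
# crops: [num_boxes, crop_height, crop_width, depth] = [2, 32, 32, 3]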
diff --git a/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV3.pbtxt
new file mode 100644
index 0000000000..25ec87eeca
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV3.pbtxt
@@ -0,0 +1,64 @@
+op {
+ graph_op_name: "NonMaxSuppressionV3"
+ in_arg {
+ name: "boxes"
+ description: <<END
+A 2-D float tensor of shape `[num_boxes, 4]`.
+END
+ }
+ in_arg {
+ name: "scores"
+ description: <<END
+A 1-D float tensor of shape `[num_boxes]` representing a single
+score corresponding to each box (each row of boxes).
+END
+ }
+ in_arg {
+ name: "max_output_size"
+ description: <<END
+A scalar integer tensor representing the maximum number of
+boxes to be selected by non max suppression.
+END
+ }
+ in_arg {
+ name: "iou_threshold"
+ description: <<END
+A 0-D float tensor representing the threshold for deciding whether
+boxes overlap too much with respect to IOU.
+END
+ }
+ in_arg {
+ name: "score_threshold"
+ description: <<END
+A 0-D float tensor representing the threshold for deciding when to remove
+boxes based on score.
+END
+ }
+ out_arg {
+ name: "selected_indices"
+ description: <<END
+A 1-D integer tensor of shape `[M]` representing the selected
+indices from the boxes tensor, where `M <= max_output_size`.
+END
+ }
+ summary: "Greedily selects a subset of bounding boxes in descending order of score,"
+ description: <<END
+pruning away boxes that have high intersection-over-union (IOU) overlap
+with previously selected boxes. Bounding boxes with score less than
+`score_threshold` are removed. Bounding boxes are supplied as
+[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
+diagonal pair of box corners and the coordinates can be provided as normalized
+(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
+is agnostic to where the origin is in the coordinate system and more
+generally is invariant to orthogonal transformations and translations
+of the coordinate system; thus translations or reflections of the coordinate
+system result in the same boxes being selected by the algorithm.
+The output of this operation is a set of integers indexing into the input
+collection of bounding boxes representing the selected boxes. The bounding
+box coordinates corresponding to the selected indices can then be obtained
+using the `tf.gather` operation. For example:
+ selected_indices = tf.image.non_max_suppression_v2(
+ boxes, scores, max_output_size, iou_threshold, score_threshold)
+ selected_boxes = tf.gather(boxes, selected_indices)
+END
+}
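For illustration, a NumPy sketch of the greedy selection described above; this is not the kernel, and the exact handling of ties and the IOU boundary may differ.

import numpy as np

def box_iou(a, b):
  # Boxes are [y1, x1, y2, x2]; works for normalized or absolute coordinates.
  inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
  inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
  inter = inter_h * inter_w
  area = lambda box: max(0.0, box[2] - box[0]) * max(0.0, box[3] - box[1])
  union = area(a) + area(b) - inter
  return inter / union if union > 0.0 else 0.0

def greedy_nms(boxes, scores, max_output_size, iou_threshold, score_threshold):
  # Drop low-scoring boxes, then greedily keep the highest-scoring remaining
  # box and suppress any later box that overlaps a kept box too much.
  order = [i for i in np.argsort(-scores) if scores[i] >= score_threshold]
  selected = []
  for i in order:
    if len(selected) == max_output_size:
      break
    if all(box_iou(boxes[i], boxes[j]) <= iou_threshold for j in selected):
      selected.append(i)
  return np.array(selected, dtype=np.int32)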
diff --git a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
new file mode 100644
index 0000000000..8cef243aee
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
@@ -0,0 +1,30 @@
+op {
+ graph_op_name: "RegexFullMatch"
+ in_arg {
+ name: "input"
+ description: <<END
+A string tensor of the text to be processed.
+END
+ }
+ in_arg {
+ name: "pattern"
+ description: <<END
+A scalar string tensor containing the regular expression to match the input.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A bool tensor with the same shape as `input`.
+END
+ }
+ summary: "Check if the input matches the regex pattern."
+ description: <<END
+The input is a string tensor of any shape. The pattern is a scalar
+string tensor which is applied to every element of the input tensor.
+The boolean values (True or False) of the output tensor indicate
+if the input matches the regex pattern provided.
+
+The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+END
+}
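For illustration, a NumPy/Python sketch of the elementwise full-match semantics; Python's `re` stands in for RE2 here, whose syntax differs in places.

import re
import numpy as np

def regex_full_match(inputs, pattern):
  # The scalar pattern is applied to every element; an element matches only
  # if the whole string matches (full match, not search).
  compiled = re.compile(pattern)
  flat = [compiled.fullmatch(s) is not None for s in np.ravel(inputs)]
  return np.array(flat, dtype=bool).reshape(np.shape(inputs))

inputs = np.array([['abc123', 'abc'], ['123abc', '42']])
print(regex_full_match(inputs, r'[a-z]+\d+'))
# [[ True False]
#  [False False]]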
diff --git a/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV3.pbtxt b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV3.pbtxt
new file mode 100644
index 0000000000..263cba14eb
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV3.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "NonMaxSuppressionV3"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/python_api/api_def_RegexFullMatch.pbtxt
new file mode 100644
index 0000000000..ec310c8aeb
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_RegexFullMatch.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "RegexFullMatch"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/broadcaster.cc b/tensorflow/core/common_runtime/broadcaster.cc
index e42d3f6b92..9646a0856e 100644
--- a/tensorflow/core/common_runtime/broadcaster.cc
+++ b/tensorflow/core/common_runtime/broadcaster.cc
@@ -134,7 +134,7 @@ void Broadcaster::TreeSendTo(const CollectiveParams& cp,
// Execute a tree broadcast, i.e. each non-source device receives from
// one other and sends to up-to two others.
void Broadcaster::RunTree() {
- mutex mu;
+ mutex mu; // also guards status_ while callbacks are pending
int pending_count = 0; // GUARDED_BY(mu)
condition_variable all_done;
std::vector<int> send_to_ranks;
@@ -162,15 +162,13 @@ void Broadcaster::RunTree() {
++pending_count;
}
DispatchSend(
- target_rank, output_,
+ target_rank, (is_source_ ? &ctx_->input(0) : output_),
[this, target_rank, &mu, &pending_count, &all_done](const Status& s) {
+ mutex_lock l(mu);
status_.Update(s);
- {
- mutex_lock l(mu);
- --pending_count;
- if (pending_count == 0) {
- all_done.notify_all();
- }
+ --pending_count;
+ if (pending_count == 0) {
+ all_done.notify_all();
}
});
}
@@ -191,13 +189,11 @@ void Broadcaster::RunTree() {
op_dev_ctx, op_dev_ctx, device_, device_, ctx_->input_alloc_attr(0),
ctx_->output_alloc_attr(0), input, output_,
[this, &mu, &pending_count, &all_done](const Status& s) {
+ mutex_lock l(mu);
status_.Update(s);
- {
- mutex_lock l(mu);
- --pending_count;
- if (0 == pending_count) {
- all_done.notify_all();
- }
+ --pending_count;
+ if (0 == pending_count) {
+ all_done.notify_all();
}
});
}
diff --git a/tensorflow/core/common_runtime/broadcaster_test.cc b/tensorflow/core/common_runtime/broadcaster_test.cc
index 89d39144b3..959b93d56e 100644
--- a/tensorflow/core/common_runtime/broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/broadcaster_test.cc
@@ -314,11 +314,11 @@ class BroadcasterTest : public ::testing::Test {
typedef std::function<void(Tensor*)> InitFunc;
- void Broadcast() {
+ void Broadcast(bool forward_input) {
std::atomic<int> done(0);
for (auto di : instances_) {
- SchedClosure([di, &done] {
- di->DoBroadcast();
+ SchedClosure([di, forward_input, &done] {
+ di->DoBroadcast(forward_input);
++done;
});
}
@@ -380,7 +380,8 @@ class BroadcasterTest : public ::testing::Test {
template <typename T>
void RunTest(DataType dtype, const DeviceType& device_type, int num_workers,
- int num_devices, int tensor_len, int fail_after) {
+ int num_devices, int tensor_len, int fail_after,
+ bool forward_input) {
Init(num_workers, num_devices, dtype, device_type, fail_after);
// Initialize each instance tensor with distinct values.
@@ -423,7 +424,7 @@ class BroadcasterTest : public ::testing::Test {
expected[i] = t->flat<T>()(i);
}
- Broadcast();
+ Broadcast(forward_input);
// At this point all of the ops have terminated.
for (int di = 0; di < instances_.size(); ++di) {
@@ -573,7 +574,7 @@ class BroadcasterTest : public ::testing::Test {
}
}
- void DoBroadcast() {
+ void DoBroadcast(bool forward_input) {
// Prepare an OpKernelContext.
OpKernelContext::Params op_params;
op_params.step_id = parent_->step_id_;
@@ -596,7 +597,8 @@ class BroadcasterTest : public ::testing::Test {
input_dc.push_back(dev_ctx);
op_params.input_device_contexts = &input_dc;
op_params.op_device_context = dev_ctx;
- int forward_from[] = {0};
+ int forward_from[] = {OpKernelContext::Params::kNeverForward};
+ if (forward_input) forward_from[0] = 0;
if (col_params_.is_source) {
op_params.forward_from_array = &forward_from[0];
}
@@ -680,61 +682,61 @@ class BroadcasterTest : public ::testing::Test {
// D = number of devices per worker
// L = tensor length
// A = abort after count
-#define DEF_TEST(B, T, W, D, L, A) \
- TEST_F(BroadcasterTest, \
- DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Len##L##_Abt##A) { \
- DataType dtype = DT_##B; \
- switch (dtype) { \
- case DT_FLOAT: { \
- RunTest<float>(dtype, DEVICE_##T, W, D, L, A); \
- } break; \
- case DT_DOUBLE: { \
- RunTest<double>(dtype, DEVICE_##T, W, D, L, A); \
- } break; \
- case DT_INT32: { \
- RunTest<int32>(dtype, DEVICE_##T, W, D, L, A); \
- } break; \
- case DT_INT64: { \
- RunTest<int64>(dtype, DEVICE_##T, W, D, L, A); \
- } break; \
- default: \
- LOG(FATAL) << "Unimplemented"; \
- } \
+#define DEF_TEST(B, T, W, D, L, A, F) \
+ TEST_F(BroadcasterTest, \
+ DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Len##L##_Abt##A##_Fw##F) { \
+ DataType dtype = DT_##B; \
+ switch (dtype) { \
+ case DT_FLOAT: { \
+ RunTest<float>(dtype, DEVICE_##T, W, D, L, A, F); \
+ } break; \
+ case DT_DOUBLE: { \
+ RunTest<double>(dtype, DEVICE_##T, W, D, L, A, F); \
+ } break; \
+ case DT_INT32: { \
+ RunTest<int32>(dtype, DEVICE_##T, W, D, L, A, F); \
+ } break; \
+ case DT_INT64: { \
+ RunTest<int64>(dtype, DEVICE_##T, W, D, L, A, F); \
+ } break; \
+ default: \
+ LOG(FATAL) << "Unimplemented"; \
+ } \
}
#ifndef GOOGLE_CUDA
-// B T W D L A
-DEF_TEST(FLOAT, CPU, 1, 2, 1, 0)
-DEF_TEST(FLOAT, CPU, 1, 2, 1001, 0)
-DEF_TEST(FLOAT, CPU, 2, 1, 128, 0)
-DEF_TEST(FLOAT, CPU, 2, 4, 128, 0)
-DEF_TEST(FLOAT, CPU, 2, 8, 4095, 0)
-DEF_TEST(FLOAT, CPU, 4, 4, 1045991, 0)
-
-DEF_TEST(DOUBLE, CPU, 2, 4, 128, 0)
-DEF_TEST(INT32, CPU, 2, 4, 128, 0)
-DEF_TEST(INT64, CPU, 2, 4, 128, 0)
+// B T W D L A F
+DEF_TEST(FLOAT, CPU, 1, 2, 1, 0, false)
+DEF_TEST(FLOAT, CPU, 1, 2, 1001, 0, true)
+DEF_TEST(FLOAT, CPU, 2, 1, 128, 0, false)
+DEF_TEST(FLOAT, CPU, 2, 4, 128, 0, true)
+DEF_TEST(FLOAT, CPU, 2, 8, 4095, 0, false)
+DEF_TEST(FLOAT, CPU, 4, 4, 1045991, 0, true)
+
+DEF_TEST(DOUBLE, CPU, 2, 4, 128, 0, false)
+DEF_TEST(INT32, CPU, 2, 4, 128, 0, true)
+DEF_TEST(INT64, CPU, 2, 4, 128, 0, false)
// Failure cases
-DEF_TEST(FLOAT, CPU, 2, 4, 128, 1)
-DEF_TEST(FLOAT, CPU, 2, 4, 128, 5)
+DEF_TEST(FLOAT, CPU, 2, 4, 128, 1, true)
+DEF_TEST(FLOAT, CPU, 2, 4, 128, 5, false)
#endif
#ifdef GOOGLE_CUDA
// Can only set W=1 for GPU tests.
-// B T W D L A
-DEF_TEST(FLOAT, GPU, 1, 2, 1, 0)
-DEF_TEST(FLOAT, GPU, 1, 2, 33, 0)
-DEF_TEST(FLOAT, GPU, 1, 3, 64, 0)
-DEF_TEST(FLOAT, GPU, 1, 8, 1001, 0)
-DEF_TEST(FLOAT, GPU, 1, 8, 4095, 0)
-DEF_TEST(FLOAT, GPU, 1, 8, 1045991, 0)
+// B T W D L A F
+DEF_TEST(FLOAT, GPU, 1, 2, 1, 0, true)
+DEF_TEST(FLOAT, GPU, 1, 2, 33, 0, false)
+DEF_TEST(FLOAT, GPU, 1, 3, 64, 0, true)
+DEF_TEST(FLOAT, GPU, 1, 8, 1001, 0, false)
+DEF_TEST(FLOAT, GPU, 1, 8, 4095, 0, true)
+DEF_TEST(FLOAT, GPU, 1, 8, 1045991, 0, false)
-DEF_TEST(DOUBLE, GPU, 1, 8, 1001, 0)
-DEF_TEST(INT64, GPU, 1, 8, 1001, 0)
+DEF_TEST(DOUBLE, GPU, 1, 8, 1001, 0, true)
+DEF_TEST(INT64, GPU, 1, 8, 1001, 0, false)
// Failure cases
-DEF_TEST(FLOAT, GPU, 1, 8, 128, 6)
+DEF_TEST(FLOAT, GPU, 1, 8, 128, 6, true)
#endif
} // namespace
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index bdddf927d8..1178f8624c 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -34,7 +34,8 @@ void CollectiveParamResolverLocal::CompleteGroupAsync(
void CollectiveParamResolverLocal::CompleteGroupLocal(
const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
- VLOG(1) << "CompleteGroupLocal " << cp << ": " << cp->ToString();
+ VLOG(1) << "CompleteGroupLocal device=" << device << " cp: " << cp << ": "
+ << cp->ToString();
std::vector<StatusCallback> to_be_called;
GroupRec* gr = nullptr;
{
@@ -434,8 +435,9 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
}
}
-Status CollectiveParamResolverLocal::InitInstanceSharedParams(
- const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
+void CollectiveParamResolverLocal::InitInstanceSharedParams(
+ const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
+ const StatusCallback& done) {
VLOG(1) << "InitInstanceSharedParams " << ir;
ir->shared.instance = cp->instance;
{
@@ -461,19 +463,19 @@ Status CollectiveParamResolverLocal::InitInstanceSharedParams(
// called by a derived class, some of the devices may be non-local and
// GetDeviceLocalitiesAsync will use those fields to launch RPCs.
CompleteTaskIsLocal(task_name_, &ir->shared);
- std::vector<DeviceLocality> localities;
- Notification note;
- Status status;
- dev_resolver_->GetDeviceLocalitiesAsync(ir->shared.instance, &localities,
- [&note, &status](const Status& s) {
- status = s;
- note.Notify();
- });
- note.WaitForNotification();
- if (status.ok()) {
- CompleteDefaultRanking(gr, cp, ir, localities);
- }
- return status;
+ std::vector<DeviceLocality>* localities = new std::vector<DeviceLocality>;
+ dev_resolver_->GetDeviceLocalitiesAsync(
+ ir->shared.instance, localities,
+ [this, gr, cp, ir, localities, done](const Status& s)
+ EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu) {
+ if (s.ok()) {
+ CompleteDefaultRanking(gr, cp, ir, *localities);
+ done(Status::OK());
+ } else {
+ done(s);
+ }
+ delete localities;
+ });
}
void CollectiveParamResolverLocal::CompleteDefaultRanking(
@@ -548,28 +550,50 @@ void CollectiveParamResolverLocal::FindInstanceRec(
CallbackWithStatus(done, irec);
return;
}
- // Initialize the new InstanceRec while holding out_mu.
- {
- mutex_lock il(irec->out_mu);
- irec->known.resize(cp->group.group_size, false);
- irec->status = InitInstanceSharedParams(gr, cp, irec);
- }
- // Prepare to invoke any waiters that accumlated during initialization.
- std::vector<IRConsumer> init_waiters;
- {
- mutex_lock tl(instance_mu_);
- {
- mutex_lock l(irec->in_mu);
- irec->is_init = true;
- if (!irec->init_waiters.empty()) {
- std::swap(init_waiters, irec->init_waiters);
- }
- }
- }
- CallbackWithStatus(done, irec);
- for (auto& f : init_waiters) {
- f(irec);
- }
+
+ CallInitInstanceSharedParams(gr, cp, irec, done);
+}
+
+void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
+ const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
+ const InstanceRecCallback& done) NO_THREAD_SAFETY_ANALYSIS {
+ // This function serves merely to make a function call that should
+ // be thread/mutex safe but violates the simple model applied by
+ // static analysis, so we turn off analysis only within this
+ // function body.
+ //
+ // A lock on ir->out_mu must be held throughout the _bodies_ of the
+ // chain of function calls initiated here, each of which calls
+ // another as its last action, but it will be dropped within the
+ // callback defined below, which means that the lock can be dropped
+ // before all the function stack frames pop. The static analysis will
+ // not allow that.
+ ir->out_mu.lock();
+ ir->known.resize(cp->group.group_size, false);
+ InitInstanceSharedParams(
+ gr, cp, ir,
+ [this, ir, done](const Status& s) UNLOCK_FUNCTION(ir->out_mu) {
+ DCHECK(!ir->out_mu.try_lock());
+ ir->status.Update(s);
+ ir->out_mu.unlock();
+ // Prepare to invoke any waiters that accumulated during
+ // initialization.
+ std::vector<IRConsumer> init_waiters;
+ {
+ mutex_lock tl(instance_mu_);
+ {
+ mutex_lock l(ir->in_mu);
+ ir->is_init = true;
+ if (!ir->init_waiters.empty()) {
+ std::swap(init_waiters, ir->init_waiters);
+ }
+ }
+ }
+ CallbackWithStatus(done, ir);
+ for (auto& f : init_waiters) {
+ f(ir);
+ }
+ });
}
void CollectiveParamResolverLocal::CompleteParamsAsync(
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 7b2946e936..3a871f962d 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -145,10 +145,15 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
//
// Preconditions:
// cp is populated with all DeviceLocalities
- Status InitInstanceSharedParams(const GroupRec* gr,
- const CollectiveParams* cp, InstanceRec* ir)
+ void InitInstanceSharedParams(const GroupRec* gr, const CollectiveParams* cp,
+ InstanceRec* ir, const StatusCallback& done)
EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu) LOCKS_EXCLUDED(gr->mu);
+ void CallInitInstanceSharedParams(const GroupRec* gr,
+ const CollectiveParams* cp, InstanceRec* ir,
+ const InstanceRecCallback& done)
+ LOCKS_EXCLUDED(ir->out_mu, gr->mu);
+
// Establishes the final order of ir->shared.instance.device_names and
// ir->shared.instance.task_names by considering localities of all devices.
void CompleteDefaultRanking(const GroupRec* gr, const CollectiveParams* cp,
diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc
index ad9b32ce35..69f1a9f24c 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local.cc
@@ -54,9 +54,13 @@ void CollectiveRemoteAccessLocal::RecvFromPeer(
hook->prod_value, // src Tensor*
to_tensor, // dst Tensor*
[hook, done](const Status& s) {
+ // This callback may be executing in the GPUEventMgr
+ // pool, in which case it must be of very short duration
+ // and non-blocking (except e.g. for queue insertion).
+ // It would be safer, though expensive, to transfer
+ // to another thread here.
done(s);
- hook->prod_cb(s);
- delete hook;
+ BufRendezvous::DoneWithHook(hook);
});
}
});
@@ -91,6 +95,21 @@ void CollectiveRemoteAccessLocal::MemCpyAsync(
dst_attr.on_host() ? DEVICE_CPU : dst_dev->attributes().device_type());
const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);
+ // For GPU devices when only one compute stream is used (the default)
+ // the OpKernelContext does not supply a DeviceContext. It's assumed
+ // that all nodes use the default context.
+ if (src_dev_ctx == nullptr && src_device_type == DEVICE_GPU) {
+ const DeviceBase::GpuDeviceInfo* dev_info =
+ src_dev->tensorflow_gpu_device_info();
+ CHECK(dev_info);
+ src_dev_ctx = dev_info->default_context;
+ }
+ if (dst_dev_ctx == nullptr && dst_device_type == DEVICE_GPU) {
+ const DeviceBase::GpuDeviceInfo* dev_info =
+ src_dev->tensorflow_gpu_device_info();
+ CHECK(dev_info);
+ dst_dev_ctx = dev_info->default_context;
+ }
if (non_cpu_src) CHECK(src_dev_ctx);
if (non_cpu_dst) CHECK(dst_dev_ctx);
if (non_cpu_src || non_cpu_dst) {
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 5918cd9bbf..b537666492 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -51,6 +51,8 @@ limitations under the License.
namespace tensorflow {
+class DeviceMgr;
+
class Device : public DeviceBase {
public:
Device(Env* env, const DeviceAttributes& device_attributes);
@@ -133,6 +135,10 @@ class Device : public DeviceBase {
// Returns the resource manager associated w/ this device.
virtual ResourceMgr* resource_manager() { return rmgr_; }
+ // Returns the device manager that owns this device, or nullptr if this Device
+ // is not owned by a device manager.
+ DeviceMgr* device_mgr() const { return device_mgr_; }
+
// Summarizes the status of this Device, for debugging.
string DebugString() const { return ProtoDebugString(device_attributes_); }
@@ -158,6 +164,11 @@ class Device : public DeviceBase {
}
private:
+ friend class DeviceMgr;
+
+ // Pointer to the device manager that owns this device. Not owned.
+ DeviceMgr* device_mgr_ = nullptr;
+
const DeviceAttributes device_attributes_;
DeviceNameUtils::ParsedName parsed_name_;
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
index a77601ba79..470abc1431 100644
--- a/tensorflow/core/common_runtime/device_mgr.cc
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -27,6 +27,9 @@ namespace tensorflow {
DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
: name_backing_store_(128) {
for (Device* d : devices) {
+ CHECK(d->device_mgr_ == nullptr);
+ d->device_mgr_ = this;
+
devices_.push_back(d);
// Register under the (1) full name and (2) canonical name.
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index bf05f6f1d9..d05564e9c4 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -208,19 +208,19 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
// The instantiated and transformed function is encoded as a Graph
// object, and an executor is created for the graph.
- struct Item : public core::RefCounted {
- bool invalidated = false;
+ struct Item {
+ uint64 instantiation_counter = 0;
const Graph* graph = nullptr; // Owned by exec.
const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned.
FunctionBody* func_graph = nullptr;
Executor* exec = nullptr;
- ~Item() override {
+ ~Item() {
delete this->func_graph;
delete this->exec;
}
};
- std::unordered_map<Handle, Item*> items_ GUARDED_BY(mu_);
+ std::unordered_map<Handle, std::unique_ptr<Item>> items_ GUARDED_BY(mu_);
ProcessFunctionLibraryRuntime* parent_ = nullptr; // not owned.
@@ -284,9 +284,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
}
}
-FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
- for (auto p : items_) p.second->Unref();
-}
+FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {}
// An asynchronous op kernel which executes an instantiated function
// defined in a library.
@@ -490,30 +488,24 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
options_copy.target = device_name_;
const string key = Canonicalize(function_name, attrs, options_copy);
- Handle found_handle = kInvalidHandle;
{
mutex_lock l(mu_);
- found_handle = parent_->GetHandle(key);
- if (found_handle != kInvalidHandle) {
+ *handle = parent_->GetHandle(key);
+ if (*handle != kInvalidHandle) {
FunctionLibraryRuntime::LocalHandle handle_on_device =
- parent_->GetHandleOnDevice(device_name_, found_handle);
+ parent_->GetHandleOnDevice(device_name_, *handle);
if (handle_on_device == kInvalidLocalHandle) {
return errors::Internal("LocalHandle not found for handle ", *handle,
".");
}
- auto iter = items_.find(handle_on_device);
- if (iter == items_.end()) {
+ auto item_handle = items_.find(handle_on_device);
+ if (item_handle == items_.end()) {
return errors::Internal("LocalHandle ", handle_on_device,
- " for handle ", found_handle,
+ " for handle ", *handle,
" not found in items.");
}
- Item* item = iter->second;
- if (!item->invalidated) {
- *handle = found_handle;
- return Status::OK();
- }
- // *item is invalidated. Fall through and instantiate the given
- // function_name/attrs/option again.
+ ++item_handle->second->instantiation_counter;
+ return Status::OK();
}
}
@@ -545,16 +537,18 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
{
mutex_lock l(mu_);
- Handle found_handle_again = parent_->GetHandle(key);
- if (found_handle_again != found_handle) {
+ *handle = parent_->GetHandle(key);
+ if (*handle != kInvalidHandle) {
delete fbody;
- *handle = found_handle_again;
+ ++items_[parent_->GetHandleOnDevice(device_name_, *handle)]
+ ->instantiation_counter;
} else {
*handle = parent_->AddHandle(key, device_name_, next_handle_);
Item* item = new Item;
item->func_graph = fbody;
item->overlay_lib = options.overlay_lib;
- items_.insert({next_handle_, item});
+ item->instantiation_counter = 1;
+ items_.emplace(next_handle_, std::unique_ptr<Item>(item));
next_handle_++;
}
}
@@ -565,12 +559,17 @@ Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
return parent_->ReleaseHandle(handle);
}
+
LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
CHECK_NE(h, kInvalidLocalHandle);
mutex_lock l(mu_);
CHECK_EQ(1, items_.count(h));
- Item* item = items_[h];
- item->invalidated = true; // Reinstantiate later.
+ std::unique_ptr<Item>& item = items_[h];
+ --item->instantiation_counter;
+ if (item->instantiation_counter == 0) {
+ items_.erase(h);
+ TF_RETURN_IF_ERROR(parent_->RemoveHandle(handle));
+ }
return Status::OK();
}
@@ -680,7 +679,7 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
return errors::NotFound("Function handle ", handle,
" is not valid. Likely an internal error.");
}
- *item = items_[local_handle];
+ *item = items_[local_handle].get();
if ((*item)->exec != nullptr) {
return Status::OK();
}
@@ -731,7 +730,6 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
// computation is done and stored in *rets, we send the return values back
// to the source_device (caller) so that the ProcFLR can receive them later.
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
- item->Ref();
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
source_device, target_device, "arg_", src_incarnation, args.size(),
device_context, {}, rendezvous, remote_args,
@@ -743,7 +741,6 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
s = frame->SetArgs(*remote_args);
}
if (!s.ok()) {
- item->Unref();
delete frame;
delete remote_args;
delete exec_args;
@@ -751,10 +748,9 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
return;
}
item->exec->RunAsync(
- *exec_args, [item, frame, rets, done, source_device, target_device,
+ *exec_args, [frame, rets, done, source_device, target_device,
target_incarnation, rendezvous, device_context,
remote_args, exec_args](const Status& status) {
- core::ScopedUnref unref(item);
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
@@ -840,13 +836,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
return;
}
- item->Ref();
item->exec->RunAsync(
// Executor args
*exec_args,
// Done callback.
- [item, frame, rets, done, exec_args](const Status& status) {
- core::ScopedUnref unref(item);
+ [frame, rets, done, exec_args](const Status& status) {
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
@@ -906,7 +900,6 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
exec_args->runner = *run_opts.runner;
exec_args->call_frame = frame;
- item->Ref();
item->exec->RunAsync(
// Executor args
*exec_args,
@@ -915,7 +908,6 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
[item, frame, exec_args](DoneCallback done,
// Start unbound arguments.
const Status& status) {
- core::ScopedUnref unref(item);
delete exec_args;
done(status);
},
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 373fc64007..61b2f0e60f 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -231,8 +231,19 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return status;
}
FunctionLibraryRuntime::Options opts;
- TF_RETURN_IF_ERROR(Run(flr, handle, opts, args, rets, add_runner));
- return flr->ReleaseHandle(handle);
+ status = Run(flr, handle, opts, args, rets, add_runner);
+ if (!status.ok()) return status;
+
+ // Release the handle and try running again. It should not succeed.
+ status = flr->ReleaseHandle(handle);
+ if (!status.ok()) return status;
+
+ Status status2 = Run(flr, handle, opts, args, std::move(rets));
+ EXPECT_TRUE(errors::IsInvalidArgument(status2));
+ EXPECT_TRUE(
+ str_util::StrContains(status2.error_message(), "remote execution."));
+
+ return status;
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
@@ -293,8 +304,16 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
*rets[i] = retvals[i];
}
- // Release the handle.
- return flr->ReleaseHandle(handle);
+ // Release the handle and try running again. It should not succeed.
+ status = flr->ReleaseHandle(handle);
+ if (!status.ok()) return status;
+
+ Status status2 = Run(flr, handle, opts, args, std::move(rets));
+ EXPECT_TRUE(errors::IsInvalidArgument(status2));
+ EXPECT_TRUE(
+ str_util::StrContains(status2.error_message(), "remote execution."));
+
+ return status;
}
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc
index 98dac38a8c..2d09e83d01 100644
--- a/tensorflow/core/common_runtime/function_threadpool_test.cc
+++ b/tensorflow/core/common_runtime/function_threadpool_test.cc
@@ -144,7 +144,19 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return status;
}
FunctionLibraryRuntime::Options opts;
- return Run(flr, handle, opts, args, std::move(rets), add_runner);
+ status = Run(flr, handle, opts, args, rets, add_runner);
+ if (!status.ok()) return status;
+
+ // Release the handle and try running again. It should not succeed.
+ status = flr->ReleaseHandle(handle);
+ if (!status.ok()) return status;
+
+ Status status2 = Run(flr, handle, opts, args, std::move(rets));
+ EXPECT_TRUE(errors::IsInvalidArgument(status2));
+ EXPECT_TRUE(
+ str_util::StrContains(status2.error_message(), "remote execution."));
+
+ return status;
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 9b434e5e2f..c84fe48084 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -104,8 +104,9 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
reinterpret_cast<unsigned int*>(scratch + Eigen::kCudaScratchSize);
stream_ = cuda_stream;
allocator_ = alloc;
- const int cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id).value();
- device_prop_ = &Eigen::m_deviceProperties[cuda_gpu_id];
+ CudaGpuId cuda_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ device_prop_ = &Eigen::m_deviceProperties[cuda_gpu_id.value()];
}
const cudaStream_t& stream() const override { return *stream_; }
@@ -317,7 +318,9 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
gpu_device_info_->stream = streams_[0]->compute;
gpu_device_info_->default_context = device_contexts_[0];
gpu_device_info_->event_mgr = em_.get();
- gpu_device_info_->gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id_).value();
+ CudaGpuId cuda_gpu_id;
+ TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
+ gpu_device_info_->gpu_id = cuda_gpu_id.value();
set_tensorflow_gpu_device_info(gpu_device_info_);
// Whether and how the GPU device uses its own threadpool.
@@ -965,7 +968,8 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
while (next_tf_gpu_id < memory_limit_bytes.size()) {
TfGpuId tf_gpu_id(next_tf_gpu_id);
++next_tf_gpu_id;
- GpuIdManager::InsertTfCudaGpuIdPair(tf_gpu_id, cuda_gpu_id);
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::InsertTfCudaGpuIdPair(tf_gpu_id, cuda_gpu_id));
}
}
const int num_tf_gpus = next_tf_gpu_id;
@@ -1016,7 +1020,8 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
const string device_name =
strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value());
GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
- CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id);
+ CudaGpuId cuda_gpu_id;
+ TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
int numa_node = dev_locality.numa_node();
se::StreamExecutor* se =
@@ -1101,7 +1106,8 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
all_tf_gpu_ids.push_back(TfGpuId(i));
}
for (TfGpuId tf_gpu_id : all_tf_gpu_ids) {
- CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id);
+ CudaGpuId cuda_gpu_id;
+ TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
// Get GPU bus_id from its reported NUMA affinity. Because GPUs are
// virtualized in some environments, we can't just use the GPU id.
// NUMA locales are indexed from 0, buses are indexed from 1.
@@ -1129,7 +1135,9 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
LocalLinks* links = dev_locality.mutable_links();
for (const InterconnectMap& imap : interconnects) {
for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
- CudaGpuId cuda_gpu_dst = GpuIdManager::TfToCudaGpuId(tf_gpu_dst);
+ CudaGpuId cuda_gpu_dst;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
if (imap.directed_links.find({cuda_gpu_id, cuda_gpu_dst}) !=
imap.directed_links.end()) {
InterconnectLink* ilink = links->add_link();
@@ -1144,7 +1152,9 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
// add high strength links to the others.
for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
if (tf_gpu_id == tf_gpu_dst) continue;
- CudaGpuId cuda_gpu_dst = GpuIdManager::TfToCudaGpuId(tf_gpu_dst);
+ CudaGpuId cuda_gpu_dst;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
if (cuda_gpu_id == cuda_gpu_dst) {
InterconnectLink* ilink = links->add_link();
ilink->set_device_id(tf_gpu_dst.value());
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index b754ffd2db..3e958a70f1 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -90,7 +90,11 @@ class BaseGPUDevice : public LocalDevice {
// Returns the CUDA GPU id of this device within the native driver system;
// e.g., for CUDA this is the ordinal of the GPU within the system.
- int gpu_id() const { return GpuIdManager::TfToCudaGpuId(tf_gpu_id_).value(); }
+ int gpu_id() const {
+ CudaGpuId cuda_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
+ return cuda_gpu_id.value();
+ }
// The executor that provides control for the device; e.g., for CUDA this
// corresponds to the cuda context.
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index f3935f6ba2..bb00173d1e 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -29,7 +29,7 @@ const char* kDeviceNamePrefix = "/job:localhost/replica:0/task:0";
class GPUDeviceTest : public ::testing::Test {
public:
- void TearDown() { ProcessState::singleton()->TestOnlyReset(); }
+ void TearDown() override { ProcessState::singleton()->TestOnlyReset(); }
protected:
static SessionOptions MakeSessionOptions(
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
index 1d4ad957b9..c5ff6c97a1 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
@@ -119,7 +119,7 @@ TEST(EventMgr, DelayedPolling) {
EXPECT_EQ(0, th.queue_size());
TensorReferenceVector* v = nullptr;
std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
- CHECK(stream.get());
+ CHECK(stream);
stream->Init();
for (int i = 0; i < 5; ++i) {
v = new TensorReferenceVector;
@@ -151,7 +151,7 @@ TEST(EventMgr, FlushLargeTensorImmediately) {
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, live_tensor_bytes);
std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
- CHECK(stream.get());
+ CHECK(stream);
stream->Init();
for (int i = 0; i < 5; ++i) {
TensorReferenceVector v;
@@ -168,7 +168,7 @@ TEST(EventMgr, ManySmallTensorsFlushedImmediately) {
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, live_tensor_bytes);
std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
- CHECK(stream.get());
+ CHECK(stream);
stream->Init();
for (int i = 0; i < 5; ++i) {
TensorReferenceVector v;
@@ -209,7 +209,7 @@ TEST(EventMgr, ManySmallTensorsSeparateCallsFlushed) {
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, live_tensor_bytes);
std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
- CHECK(stream.get());
+ CHECK(stream);
stream->Init();
for (int i = 0; i < 5; ++i) {
for (int i = 0; i < 1000; i++) {
@@ -232,7 +232,7 @@ TEST(EventMgr, NonEmptyShutdown) {
EXPECT_EQ(0, th.queue_size());
EXPECT_EQ(0, th.free_size());
std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
- CHECK(stream.get());
+ CHECK(stream);
stream->Init();
for (int i = 0; i < 5; ++i) {
TensorReferenceVector* v = new TensorReferenceVector;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
index 7dfff3269c..b5099dc8ef 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
@@ -34,46 +34,40 @@ class TfToCudaGpuIdMap {
return id_map;
}
- void InsertOrDie(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id)
- LOCKS_EXCLUDED(mu_) {
+ Status Insert(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id) LOCKS_EXCLUDED(mu_) {
std::pair<IdMapType::iterator, bool> result;
{
mutex_lock lock(mu_);
result = id_map_.insert({tf_gpu_id.value(), cuda_gpu_id.value()});
}
- if (!result.second) {
- CHECK_EQ(cuda_gpu_id.value(), result.first->second)
- << "Mapping the same TfGpuId to a different CUDA GPU id."
- << " TfGpuId: " << tf_gpu_id
- << " Existing mapped CUDA GPU id: " << result.first->second
- << " CUDA GPU id being tried to map to: " << cuda_gpu_id;
+ if (!result.second && cuda_gpu_id.value() != result.first->second) {
+ return errors::AlreadyExists(
+ "TensorFlow device (GPU:", tf_gpu_id.value(),
+ ") is being mapped to "
+ "multiple CUDA devices (",
+ cuda_gpu_id.value(), " now, and ", result.first->second,
+ " previously), which is not supported. "
+ "This may be the result of providing different GPU configurations "
+ "(ConfigProto.gpu_options, for example different visible_device_list)"
+ " when creating multiple Sessions in the same process. This is not "
+ " currently supported, see "
+ "https://github.com/tensorflow/tensorflow/issues/19083");
}
- }
-
- CudaGpuId FindOrDie(TfGpuId tf_gpu_id) const LOCKS_EXCLUDED(mu_) {
- mutex_lock lock(mu_);
- return FindOrDieLocked(tf_gpu_id);
+ return Status::OK();
}
bool Find(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) const
LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
- if (id_map_.count(tf_gpu_id.value()) == 0) return false;
- *cuda_gpu_id = FindOrDieLocked(tf_gpu_id);
+ auto result = id_map_.find(tf_gpu_id.value());
+ if (result == id_map_.end()) return false;
+ *cuda_gpu_id = result->second;
return true;
}
private:
TfToCudaGpuIdMap() = default;
- CudaGpuId FindOrDieLocked(TfGpuId tf_gpu_id) const
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- auto result = id_map_.find(tf_gpu_id.value());
- CHECK(result != id_map_.end())
- << "Could not find the mapping for TfGpuId: " << tf_gpu_id;
- return CudaGpuId(result->second);
- }
-
void TestOnlyReset() LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
id_map_.clear();
@@ -88,23 +82,19 @@ class TfToCudaGpuIdMap {
};
} // namespace
-void GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id,
- CudaGpuId cuda_gpu_id) {
- TfToCudaGpuIdMap::singleton()->InsertOrDie(tf_gpu_id, cuda_gpu_id);
+Status GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id,
+ CudaGpuId cuda_gpu_id) {
+ return TfToCudaGpuIdMap::singleton()->Insert(tf_gpu_id, cuda_gpu_id);
}
Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) {
if (TfToCudaGpuIdMap::singleton()->Find(tf_gpu_id, cuda_gpu_id)) {
return Status::OK();
}
- return errors::NotFound("TF GPU device with id ", tf_gpu_id.value(),
+ return errors::NotFound("TensorFlow device GPU:", tf_gpu_id.value(),
" was not registered");
}
-CudaGpuId GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id) {
- return TfToCudaGpuIdMap::singleton()->FindOrDie(tf_gpu_id);
-}
-
void GpuIdManager::TestOnlyReset() {
TfToCudaGpuIdMap::singleton()->TestOnlyReset();
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
index 2b54cc184c..491d92ccdd 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
@@ -26,13 +26,10 @@ namespace tensorflow {
class GpuIdManager {
public:
// Adds a mapping from tf_gpu_id to cuda_gpu_id.
- static void InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id);
+ static Status InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id);
// Gets the cuda_gpu_id associated with tf_gpu_id. Returns OK if found.
static Status TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id);
- // Similar to the above version, but returns the result, and checks fail if
- // no result is found.
- static CudaGpuId TfToCudaGpuId(TfGpuId tf_gpu_id);
// Clears the map. Used in unit tests only.
static void TestOnlyReset();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
index bdbd8d065b..a663ec7051 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
@@ -16,40 +16,45 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
-namespace test {
+namespace {
+
+CudaGpuId TfToCudaGpuId(TfGpuId tf) {
+ CudaGpuId cuda;
+ TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf, &cuda));
+ return cuda;
+}
TEST(GpuIdManagerTest, Basics) {
TfGpuId key_0(0);
CudaGpuId value_0(0);
- GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0);
- EXPECT_EQ(value_0, GpuIdManager::TfToCudaGpuId(key_0));
+ TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
+ EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
// Multiple calls to map the same value is ok.
- GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0);
- EXPECT_EQ(value_0, GpuIdManager::TfToCudaGpuId(key_0));
+ TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
+ EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
// Map a different TfGpuId to a different value.
TfGpuId key_1(3);
CudaGpuId value_1(2);
- GpuIdManager::InsertTfCudaGpuIdPair(key_1, value_1);
- EXPECT_EQ(value_1, GpuIdManager::TfToCudaGpuId(key_1));
+ TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_1, value_1));
+ EXPECT_EQ(value_1, TfToCudaGpuId(key_1));
// Mapping a different TfGpuId to the same value is ok.
TfGpuId key_2(10);
- GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_1);
- EXPECT_EQ(value_1, GpuIdManager::TfToCudaGpuId(key_2));
+ TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_1));
+ EXPECT_EQ(value_1, TfToCudaGpuId(key_2));
- // Mapping the same TfGpuId to a different value will crash the program.
- ASSERT_DEATH(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_0),
- "Mapping the same TfGpuId to a different CUDA GPU id");
+ // Mapping the same TfGpuId to a different value returns an error.
+ ASSERT_FALSE(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_0).ok());
- // Getting an nonexistent mapping will crash the program.
- ASSERT_DEATH(GpuIdManager::TfToCudaGpuId(TfGpuId(100)),
- "Could not find the mapping for TfGpuId");
+ // Getting a nonexistent mapping returns an error.
+ ASSERT_FALSE(GpuIdManager::TfToCudaGpuId(TfGpuId(100), &value_0).ok());
}
-} // namespace test
+} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
index 42bf074e63..b9c66b3328 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
@@ -39,12 +39,15 @@ class GpuIdUtil {
}
static se::port::StatusOr<se::StreamExecutor*> ExecutorForTfGpuId(
TfGpuId tf_gpu_id) {
- return ExecutorForCudaGpuId(GpuIdManager::TfToCudaGpuId(tf_gpu_id));
+ CudaGpuId cuda_gpu_id;
+ TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ return ExecutorForCudaGpuId(cuda_gpu_id);
}
// Verify that the cuda_gpu_id associated with a TfGpuId is legitimate.
static void CheckValidTfGpuId(TfGpuId tf_gpu_id) {
- const CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id);
+ CudaGpuId cuda_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
const int visible_device_count = GPUMachineManager()->VisibleDeviceCount();
CHECK_LT(cuda_gpu_id.value(), visible_device_count)
<< "cuda_gpu_id is outside discovered device range."
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
index 5ed01278c1..2b442071e2 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -126,7 +126,8 @@ Allocator* ProcessState::GetGPUAllocator(const GPUOptions& options,
return nullptr;
}
- const CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id);
+ CudaGpuId cuda_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
gpu_allocator =
new GPUBFCAllocator(cuda_gpu_id, total_bytes, options,
strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc"));
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
new file mode 100644
index 0000000000..b5fee36ff4
--- /dev/null
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -0,0 +1,283 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+
+namespace tensorflow {
+
+// TODO(jpienaar): Consider making it a public attribute.
+const char* const LowerIfOpPass::kLowerUsingSwitchMergeAttr =
+ "_lower_using_switch_merge";
+
+namespace {
+
+using NodeOut = NodeBuilder::NodeOut;
+
+// Convenience builder to make it easy to construct a conditional with a single
+// function call in the then and else branch. This first converts the if node
+// into switches (for inputs) and merges (for outputs) around a function call
+// per branch, then inlines the function calls.
+class CondBuilder {
+ public:
+ enum Branch { kElseBranch = 0, kThenBranch = 1 };
+
+ // Create a CondBuilder to create the lowering of an If op that has then and
+ // else functions named `then_fn_name` and `else_fn_name` respectively in the
+ // given graph.
+ CondBuilder(Node* if_op, const string& then_fn_name,
+ const string& else_fn_name, Graph* graph);
+
+ // Constructs the basic conditional control flow using switch and merge nodes.
+ Status CreatePivotNodes();
+
+ // Adds the inputs from the if node, each routed through a switch node, to
+ // the then and else call nodes of the lowered if.
+ Status AddInputs();
+
+ // Adds the outputs from the if node to the merge nodes of the lowered if.
+ // Note: no inputs can be added once outputs are added as the then and else
+ // nodes are finalized while adding outputs.
+ Status AddOutputs();
+
+ // Builds an identity node with the same outputs as If.
+ Status BuildLoweredIfOutput();
+
+ // Inline call nodes for then and else.
+ Status InlineCallNodes();
+
+ private:
+ // Returns unique name containing the name of the If op being rewritten
+ // (name_), infix and a suffix to ensure it is unique within the graph.
+ string NewName(const string& infix);
+
+ // Adds input to both the then and else nodes from src:src_output.
+ Status AddInput(Node* src, int src_output);
+
+ // The merged outputs of the then and else nodes.
+ std::vector<NodeOut> outputs_;
+
+ // The node that dominates all execution of the then and else body nodes.
+ Node* control_predecessor_;
+ // The original If op.
+ Node* if_op_;
+ // The identity node with the same outputs as the original If op.
+ Node* lowered_if_output_;
+ // The predicate of the conditional.
+ Node* pred_;
+ // Node corresponding to pivot_f branch of predicate switch which is
+ // the pivot node that dominates all nodes in the false/else branch.
+ Node* pivot_f_;
+ // Node corresponding to pivot_t branch of predicate switch which is
+ // the pivot node that dominates all nodes in the true/then branch.
+ Node* pivot_t_;
+ Node* then_call_node_;
+ Node* else_call_node_;
+ Graph* graph_;
+ string name_;
+
+ NodeBuilder then_call_builder_;
+ NodeBuilder else_call_builder_;
+};
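For intuition only, a tiny Python sketch of the dataflow this builder wires up: each If input is routed through a Switch keyed on the predicate, each branch's outputs come back through a Merge, and only the live branch actually runs. None of this is the real graph API.

def switch(value, pred):
  # Routes `value` to output 0 (false) or output 1 (true); the other is dead.
  return (None, value) if pred else (value, None)

def merge(false_out, true_out):
  # Forwards whichever branch actually produced a value.
  return true_out if false_out is None else false_out

def lowered_if(pred, inputs, then_fn, else_fn):
  switched = [switch(x, pred) for x in inputs]
  # In the graph both call nodes exist; at runtime only the live one fires.
  then_outs = then_fn(*[t for _, t in switched]) if pred else None
  else_outs = else_fn(*[f for f, _ in switched]) if not pred else None
  live = then_outs if pred else else_outs
  return [merge(None if pred else live[i], live[i] if pred else None)
          for i in range(len(live))]

# lowered_if(True, [2, 3], lambda a, b: (a + b,), lambda a, b: (a - b,)) -> [5]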
+
+CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
+ const string& else_fn_name, Graph* graph)
+ : if_op_(if_op),
+ graph_(graph),
+ name_(if_op->name()),
+ then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
+ else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
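+  // Input 0 of the If op is the boolean predicate; cache the node that
+  // produces it.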
+ TF_CHECK_OK(if_op_->input_node(0, &pred_));
+}
+
+Status CondBuilder::CreatePivotNodes() {
+ // Construct the basic cond body (consisting of feeding in the predicate to
+ // create pivot nodes).
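+  // The predicate is fed to the Switch as both its data and its condition
+  // input, so output 0 (false) and output 1 (true) of the Switch serve as the
+  // per-branch pivots consumed by the Identity nodes below.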
+ Node* switch_pred;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry())
+ .Input(NodeOut(pred_, 0))
+ .Input(NodeOut(pred_, 0))
+ .Finalize(graph_, &switch_pred));
+ control_predecessor_ = switch_pred;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry())
+ .Input(switch_pred, kElseBranch)
+ .Finalize(graph_, &pivot_f_));
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry())
+ .Input(switch_pred, kThenBranch)
+ .Finalize(graph_, &pivot_t_));
+ return Status::OK();
+}
+
+string CondBuilder::NewName(const string& infix) {
+ return graph_->NewName(strings::StrCat(name_, "/", infix));
+}
+
+Status CondBuilder::AddInput(Node* src, int src_output) {
+ Node* input;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry())
+ .Input(src, src_output)
+ .Input(pred_, 0)
+ .Finalize(graph_, &input));
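+  // Port 1 (true) of the per-input Switch feeds the then call; port 0 (false)
+  // feeds the else call.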
+ then_call_builder_.Input(input, kThenBranch);
+ else_call_builder_.Input(input, kElseBranch);
+ return Status::OK();
+}
+
+Status CondBuilder::AddInputs() {
+ // Add input data edges.
+ std::vector<const Edge*> edges;
+ TF_RETURN_IF_ERROR(if_op_->input_edges(&edges));
+ // Start at index 1 as the first input is the predicate.
+ for (int i = 1; i < edges.size(); ++i) {
+ const Edge* e = edges[i];
+ TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output()));
+ }
+ // Add input control edges.
+ for (const Edge* e : if_op_->in_edges()) {
+ if (e->IsControlEdge()) {
+ graph_->AddControlEdge(e->src(), control_predecessor_);
+ }
+ }
+ return Status::OK();
+}
+
+Status CondBuilder::AddOutputs() {
+ // Construct the then and else nodes.
+ TF_RETURN_IF_ERROR(then_call_builder_.Finalize(graph_, &then_call_node_));
+ graph_->AddControlEdge(pivot_t_, then_call_node_);
+ TF_RETURN_IF_ERROR(else_call_builder_.Finalize(graph_, &else_call_node_));
+ graph_->AddControlEdge(pivot_f_, else_call_node_);
+
+ // Merge the outputs from the two branches.
+ std::vector<Node*> merges(then_call_node_->num_outputs());
+ outputs_.resize(merges.size());
+ for (int i = 0; i < then_call_node_->num_outputs(); ++i) {
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry())
+ .Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)})
+ .Finalize(graph_, &merges[i]));
+ outputs_[i] = NodeOut(merges[i], 0);
+ }
+
+ TF_RETURN_IF_ERROR(BuildLoweredIfOutput());
+
+ // Add outputs.
+ for (const Edge* e : if_op_->out_edges()) {
+ if (e->IsControlEdge()) {
+ graph_->AddControlEdge(lowered_if_output_, e->dst());
+ } else {
+ // Feed the outputs directly from the merge nodes so that downstream ops
+ // can start before all the outputs have been computed.
+      // The merged value is always at output port 0 of the Merge node.
+      graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input());
+ }
+ }
+ return Status::OK();
+}
+
+Status InlineCallInGraph(Node* n, Graph* g) {
+ const auto& lib = g->flib_def();
+ const FunctionDef* fdef = lib.Find(n->type_string());
+ CHECK(fdef != nullptr);
+ FunctionBody* fbody;
+ TF_RETURN_IF_ERROR(
+ FunctionDefToBodyHelper(*fdef, n->attrs(), &lib,
+ [&lib](const string& op, const OpDef** sig) {
+ return lib.LookUpOpDef(op, sig);
+ },
+ &fbody));
+ // TODO(jpienaar): Improve this interface to make the need to delete it
+ // explicit.
+ InlineFunctionBody(g->flib_def(), g, n, fbody);
+ delete fbody;
+ return Status::OK();
+}
+
+Status CondBuilder::BuildLoweredIfOutput() {
+ // Build the identity node output.
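+  // The IdentityN reuses the original If op's name (not NewName), presumably
+  // so that references to the op by name keep resolving after the If node is
+  // removed.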
+ NodeBuilder ib(name_, "IdentityN");
+ ib.Input(outputs_);
+ return ib.Finalize(graph_, &lowered_if_output_);
+}
+
+Status CondBuilder::InlineCallNodes() {
+ TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, graph_));
+ return Status::OK();
+}
+
+} // namespace
+
+Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) {
+ if (options.partition_graphs != nullptr) {
+ return errors::Internal(
+ "Lowering If op should happen before partitioning.");
+ }
+ if (options.graph == nullptr) {
+ return Status::OK();
+ }
+
+ Graph* g = options.graph->get();
+ if (g == nullptr) {
+ return errors::Internal("Lowering If op requires a graph to be available.");
+ }
+
+ // Match all the nodes that need to be rewritten.
+ gtl::InlinedVector<Node*, 2> matches;
+ for (Node* n : g->op_nodes()) {
+ if (n->type_string() == "If") {
+ // Only rewrite if the If op is marked as needing to be lowered.
+ bool match;
+ Status s = GetNodeAttr(n->attrs(), kLowerUsingSwitchMergeAttr, &match);
+ if (s.ok() && match) matches.push_back(n);
+ }
+ }
+ for (Node* n : matches) {
+ TF_RETURN_IF_ERROR(RewriteNode(n, g));
+ }
+ return Status::OK();
+}
+
+Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
+ const AttrValue* then_attr = n->attrs().Find("then_branch");
+ if (then_attr == nullptr) {
+ return errors::InvalidArgument("Then branch function missing");
+ }
+ const AttrValue* else_attr = n->attrs().Find("else_branch");
+ if (else_attr == nullptr) {
+ return errors::InvalidArgument("Else branch function missing");
+ }
+
+ CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), g);
+ TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
+ TF_RETURN_IF_ERROR(cb.AddInputs());
+ TF_RETURN_IF_ERROR(cb.AddOutputs());
+ TF_RETURN_IF_ERROR(cb.InlineCallNodes());
+ g->RemoveNode(n);
+
+ return Status::OK();
+}
+
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
+ LowerIfOpPass);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/lower_if_op.h b/tensorflow/core/common_runtime/lower_if_op.h
new file mode 100644
index 0000000000..a9ef39ae5c
--- /dev/null
+++ b/tensorflow/core/common_runtime/lower_if_op.h
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Rewrite If ops to use switch and merge nodes instead.
+class LowerIfOpPass : public GraphOptimizationPass {
+ public:
+ static const char* const kLowerUsingSwitchMergeAttr;
+
+ Status Run(const GraphOptimizationPassOptions& options) override;
+
+ private:
+ // Rewrite the given If node `n` in graph `g` to use the switch-merge form.
+ Status RewriteNode(Node* n, Graph* g);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_
diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc
new file mode 100644
index 0000000000..319a617b32
--- /dev/null
+++ b/tensorflow/core/common_runtime/lower_if_op_test.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+Status Rewrite(std::unique_ptr<Graph>* graph) {
+ FunctionDefLibrary flib;
+ FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = graph;
+ opt_options.flib_def = &flib_def;
+ LowerIfOpPass pass;
+ return pass.Run(opt_options);
+}
+
+TEST(LowerIfOpTest, Simple) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ // Add test functions for then and else branch.
+ FunctionDefLibrary f_lib_proto;
+ *(f_lib_proto.add_function()) = test::function::XTimesTwo();
+ *(f_lib_proto.add_function()) = test::function::XTimesFour();
+ FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
+
+  // Construct a simple conditional that switches on `pred` and operates only
+  // on a single input `A`.
+ Scope root = Scope::NewRootScope().ExitOnError();
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
+ auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
+ auto pred = ops::_Arg(root.WithOpName("pred"), DT_BOOL, 1);
+ Node* written_if;
+ std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
+ AttrValue tb;
+ tb.mutable_func()->set_name("XTimesTwo");
+ AttrValue eb;
+ eb.mutable_func()->set_name("XTimesFour");
+ TF_ASSERT_OK(NodeBuilder("if", "If", &f_lib)
+ .Input(pred.node())
+ .Input(inputs)
+ .Attr("then_branch", tb)
+ .Attr("else_branch", eb)
+ .Attr(LowerIfOpPass::kLowerUsingSwitchMergeAttr, true)
+ .Attr("Tout", {DT_INT32})
+ .Finalize(root.graph(), &written_if));
+ TF_ASSERT_OK(root.DoShapeInference(written_if));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ // The input graph has no switch or merge nodes.
+ int node_called_if_count = 0;
+ for (const auto* op : graph->op_nodes()) {
+ ASSERT_FALSE(op->IsSwitch());
+ ASSERT_FALSE(op->IsMerge());
+ if (op->name() == "if") {
+ ++node_called_if_count;
+ }
+ }
+ ASSERT_EQ(node_called_if_count, 1);
+
+ TF_ASSERT_OK(Rewrite(&graph));
+
+ // Verify the resultant graph has switch and merge nodes, and a node called
+ // `if` (but not If nodes).
+ int switch_count = 0;
+ int merge_count = 0;
+ node_called_if_count = 0;
+ for (const auto* op : graph->op_nodes()) {
+ if (op->IsSwitch()) {
+ ++switch_count;
+ }
+ if (op->IsMerge()) {
+ ++merge_count;
+ }
+ ASSERT_NE(op->type_string(), "If");
+ if (op->name() == "if") {
+ ++node_called_if_count;
+ }
+ }
+  // One Switch for the predicate and one for the input (A).
+ ASSERT_EQ(switch_count, 2);
+  // One Merge for the single output value of the then and else branches.
+ ASSERT_EQ(merge_count, 1);
+ ASSERT_EQ(node_called_if_count, 1);
+
+ // Verify execution.
+ ClientSession session(root);
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(pred.node()), Input::Initializer(false));
+ feeds.emplace(Output(a.node()), Input::Initializer(10));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
+ EXPECT_EQ(out_tensors.size(), 1);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
+ }
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(pred.node()), Input::Initializer(true));
+ feeds.emplace(Output(a.node()), Input::Initializer(10));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
+ EXPECT_EQ(out_tensors.size(), 1);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index e61ed8c479..729312a310 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -144,7 +145,8 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
}
Device* device = flr->device();
string device_type = device->parsed_name().type;
- if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
+ if (device_type == "CPU" || device_type == "TPU_SYSTEM" ||
+ device_type == "TPU") {
// "TPU_SYSTEM" indicates that `device` is a CPU.
return Status::OK();
}
@@ -182,8 +184,8 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
FunctionLibraryRuntime::LocalHandle local_handle) {
mutex_lock l(mu_);
auto h = next_handle_;
- FunctionData* fd = new FunctionData(device_name, local_handle);
- function_data_[h] = std::unique_ptr<FunctionData>(fd);
+ function_data_[h] = MakeUnique<FunctionData>(
+ device_name, local_handle, function_key);
table_[function_key] = h;
next_handle_++;
return h;
@@ -246,8 +248,8 @@ Status ProcessFunctionLibraryRuntime::Instantiate(
gtl::FindWithDefault(table_, function_key, kInvalidHandle);
if (h == kInvalidHandle || function_data_.count(h) == 0) {
h = next_handle_;
- FunctionData* fd = new FunctionData(options.target, kInvalidHandle);
- function_data_[h] = std::unique_ptr<FunctionData>(fd);
+ function_data_[h] = MakeUnique<FunctionData>(
+ options.target, kInvalidHandle, function_key);
table_[function_key] = h;
next_handle_++;
}
@@ -262,6 +264,14 @@ Status ProcessFunctionLibraryRuntime::Instantiate(
return Status::OK();
}
+Status ProcessFunctionLibraryRuntime::RemoveHandle(
+ FunctionLibraryRuntime::Handle handle) {
+ mutex_lock l(mu_);
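+  // Drop both the function_key -> handle mapping and the per-handle
+  // FunctionData entry.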
+ table_.erase(function_data_[handle]->function_key());
+ function_data_.erase(handle);
+ return Status::OK();
+}
+
Status ProcessFunctionLibraryRuntime::ReleaseHandle(
FunctionLibraryRuntime::Handle handle) {
FunctionLibraryRuntime* flr = nullptr;
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 05e5770899..69381dd34d 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -134,6 +134,9 @@ class ProcessFunctionLibraryRuntime {
// of the device where the function is registered.
string GetDeviceName(FunctionLibraryRuntime::Handle handle);
+ // Removes handle from the state owned by this object.
+ Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
+
Status Clone(Env* env, int graph_def_version,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
@@ -147,10 +150,14 @@ class ProcessFunctionLibraryRuntime {
class FunctionData {
public:
FunctionData(const string& target_device,
- FunctionLibraryRuntime::LocalHandle local_handle)
- : target_device_(target_device), local_handle_(local_handle) {}
+ FunctionLibraryRuntime::LocalHandle local_handle,
+ const string& function_key)
+ : target_device_(target_device),
+ local_handle_(local_handle),
+ function_key_(function_key) {}
string target_device() { return target_device_; }
+ const string& function_key() { return function_key_; }
FunctionLibraryRuntime::LocalHandle local_handle() {
mutex_lock l(mu_);
@@ -169,6 +176,7 @@ class ProcessFunctionLibraryRuntime {
const string target_device_;
FunctionLibraryRuntime::LocalHandle local_handle_ GUARDED_BY(mu_);
+ const string function_key_;
bool init_started_ GUARDED_BY(mu_) = false;
Status init_result_ GUARDED_BY(mu_);
Notification init_done_;
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index cc10e77ad2..cce2308011 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -39,7 +39,7 @@ class TestClusterFLR : public DistributedFunctionLibraryRuntime {
Status Instantiate(const string& function_name,
const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
- FunctionLibraryRuntime::LocalHandle* handle) {
+ FunctionLibraryRuntime::LocalHandle* handle) override {
mutex_lock l(mu_);
*handle = next_handle_;
next_handle_++;
@@ -49,7 +49,7 @@ class TestClusterFLR : public DistributedFunctionLibraryRuntime {
void Run(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done) {}
+ FunctionLibraryRuntime::DoneCallback done) override {}
private:
mutex mu_;
@@ -119,13 +119,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
EXPECT_GE(call_count, 1); // Test runner is used.
- // Release the handle and then try running the function. It
- // should still succeed.
+ // Release the handle and then try running the function. It shouldn't
+ // succeed.
status = proc_flr_->ReleaseHandle(handle);
if (!status.ok()) {
return status;
}
-
Notification done2;
proc_flr_->Run(opts, handle, args, &out,
[&status, &done2](const Status& s) {
@@ -133,7 +132,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
done2.Notify();
});
done2.WaitForNotification();
- return status;
+ EXPECT_TRUE(errors::IsNotFound(status));
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "not found."));
+
+ return Status::OK();
}
std::vector<Device*> devices_;
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index a74c502a92..f8428f2fde 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -157,21 +157,27 @@ void RingReducer::Run(StatusCallback done) {
// we're not computing in-place on the input tensor.
if ((input_ != output_) &&
(DMAHelper::base(input_) != DMAHelper::base(output_))) {
+    // We are running in a blockable thread and the callback can't block, so
+ // just wait here on the copy.
+ Notification note;
CollectiveRemoteAccessLocal::MemCpyAsync(
ctx_->input_device_context(0), ctx_->op_device_context(), device_,
device_, ctx_->input_alloc_attr(0), ctx_->output_alloc_attr(0), input_,
- output_, [this](const Status& s) {
- if (!s.ok()) {
- done_(s);
- } else {
- ContinueAfterInputCopy();
- }
+ output_, [this, &note, &status](const Status& s) {
+ status.Update(s);
+ note.Notify();
});
- } else {
- ContinueAfterInputCopy();
+ note.WaitForNotification();
+ if (!status.ok()) {
+ done_(status);
+ return;
+ }
}
+ ContinueAfterInputCopy();
}
+// Note that this function is blocking and must not run in any thread
+// which cannot be blocked.
void RingReducer::ContinueAfterInputCopy() {
AllocatorAttributes attr = ctx_->output_alloc_attr(0);
ca_.reset(MakeCollectiveAdapter(output_, group_size_ * num_subdivs_,
@@ -235,6 +241,7 @@ void RingReducer::Finish(bool ok) {
mutex_lock l(status_mu_);
s = status_;
}
+ rfv_.clear(); // Give up Refs on output tensor.
done_(s);
}
@@ -252,6 +259,7 @@ RingReducer::SubContext::SubContext(OpKernelContext* ctx,
sub_params_.input_device_contexts = &sub_input_dc_;
sub_params_.eigen_gpu_device = nullptr;
sub_params_.ensure_eigen_gpu_device();
+ sub_params_.forward_from_array = &forward_from_;
sub_ctx_ = new OpKernelContext(&sub_params_, 1);
}
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index 5fab740e92..1528c7f130 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -90,7 +90,6 @@ tf_cuda_library(
deps = [
":debug",
"//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:device_tracer",
"//tensorflow/core:direct_session_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 256ce527a4..18b7069dbe 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -453,6 +453,40 @@ cc_library(
)
cc_library(
+ name = "collective_rma_distributed",
+ srcs = ["collective_rma_distributed.cc"],
+ hdrs = ["collective_rma_distributed.h"],
+ deps = [
+ ":worker_cache",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal", # protobuf::Any
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "collective_rma_distributed_test",
+ size = "small",
+ srcs = ["collective_rma_distributed_test.cc"],
+ deps = [
+ ":collective_rma_distributed",
+ ":device_resolver_distributed",
+ ":test_utils",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+cc_library(
name = "collective_param_resolver_distributed",
srcs = ["collective_param_resolver_distributed.cc"],
hdrs = ["collective_param_resolver_distributed.h"],
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
index ecf5db8110..7a93b54eae 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -284,7 +284,6 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed(
const GroupRecCallback& done) {
VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
<< " dev: " << device << " is_leader=" << (group_leader_.empty());
- VLOG(0) << "cp: " << cp->ToString();
if (group_leader_.empty()) {
// This is the group leader, so resolution is local.
return CompleteGroupLocal(device, cp, done);
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
new file mode 100644
index 0000000000..c15878bfd3
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -0,0 +1,209 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/platform/protobuf_internal.h"
+#include "tensorflow/core/protobuf/transport_options.pb.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Supports client-side cancellation of WorkerInterface calls via
+// registration with a CancellationManager.
+//
+// TODO(tucker): Maybe unify this with CancellableCall in
+// collective_param_resolver_distributed.cc.
+class CancellableCall {
+ public:
+ CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
+ WorkerCacheInterface* wc)
+ : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) {
+ wi_ = wc_->CreateWorker(remote_worker_);
+ }
+ virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
+
+ virtual void IssueCall(const StatusCallback& done) = 0;
+
+ void Start(const StatusCallback& done) {
+ CancellationToken token = cancel_mgr_->get_cancellation_token();
+ const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
+ token, [this, token]() { opts_.StartCancel(); });
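+    // RegisterCallback returns false if cancellation has already been
+    // requested; in that case the RPC is never issued.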
+ if (not_yet_cancelled) {
+ IssueCall([this, token, done](const Status& s) {
+ cancel_mgr_->DeregisterCallback(token);
+ done(s);
+ });
+ } else {
+ done(errors::Cancelled("RPC Request was cancelled"));
+ }
+ }
+
+ protected:
+ mutable mutex mu_;
+ CancellationManager* cancel_mgr_; // Not owned
+ const string remote_worker_;
+ WorkerCacheInterface* wc_; // Not owned
+ WorkerInterface* wi_; // Owned by wc_, must be released.
+ CallOptions opts_;
+};
+
+class RecvBufCall : public CancellableCall {
+ public:
+ RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task,
+ const string& key, Device* to_device,
+ DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+ const DeviceLocality& client_locality,
+ const DeviceLocality& server_locality,
+ CancellationManager* cancel_mgr, WorkerCacheInterface* wc)
+ : CancellableCall(cancel_mgr, peer_task, wc) {
+ req_.set_step_id(step_id);
+ req_.set_buf_rendezvous_key(key);
+ *req_.mutable_client_locality() = client_locality;
+ *req_.mutable_server_locality() = server_locality;
+ req_.set_num_bytes(to_tensor->TotalBytes());
+ req_.set_buf_ptr(reinterpret_cast<int64>(DMAHelper::base(to_tensor)));
+ req_.set_src_device(peer_device);
+ req_.set_dst_device(to_device->name());
+ }
+
+ ~RecvBufCall() override {}
+
+ void IssueCall(const StatusCallback& done) override {
+ wi_->RecvBufAsync(&opts_, &req_, &resp_, done);
+ }
+
+ RecvBufRequest req_;
+ RecvBufResponse resp_;
+};
+
+} // namespace
+
+void CollectiveRemoteAccessDistributed::RecvFromPeer(
+ const string& peer_device, const string& peer_task, bool peer_is_local,
+ const string& key, Device* to_device, DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+ const DeviceLocality& client_locality, const StatusCallback& done) {
+ if (peer_is_local) {
+ CollectiveRemoteAccessLocal::RecvFromPeer(
+ peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
+ to_alloc_attr, to_tensor, client_locality, done);
+ return;
+ }
+
+ // State that needs to be threaded through a couple of async calls
+ // in order to make this function completely non-blocking.
+ struct State {
+ DeviceLocality server_locality;
+ std::unique_ptr<RecvBufCall> call;
+ };
+ State* state = new State;
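+  // `state` is deleted by recv_buf_callback on every completion path.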
+
+  // Logic to be executed in the RecvBufAsync callback.
+ auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
+ to_device_ctx, to_tensor, done](const Status& s) {
+ if (s.ok()) {
+ // In this generic implementation the bytes come back in the
+ // RPC response protobuf rather than via RDMA so we need to copy
+ // them into the destination tensor here.
+ RecvBufRespExtra extra;
+ state->call->resp_.transport_options().UnpackTo(&extra);
+ int64 num_bytes = extra.tensor_content().size();
+ if (num_bytes != to_tensor->TotalBytes()) {
+ done(errors::Internal("RecvBufResponse returned ", num_bytes,
+ " bytes where to_tensor expected ",
+ to_tensor->TotalBytes()));
+ delete state;
+ return;
+ }
+ if (to_device->tensorflow_gpu_device_info()) {
+ // Move the bytes into a CPU tensor then use tensor-to-tensor copy.
+ // Use GPU-registered memory for the CPU tensor so the transfer
+ // goes faster.
+ Device* cpu_dev = nullptr;
+ Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev);
+ if (!status.ok()) {
+ done(status);
+ delete state;
+ return;
+ }
+ AllocatorAttributes cpu_attr;
+ cpu_attr.set_gpu_compatible(true);
+ Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
+ to_tensor->dtype(), to_tensor->shape());
+ memcpy(DMAHelper::base(cpu_tensor), extra.tensor_content().data(),
+ num_bytes);
+ // Then copy it to the GPU.
+ CopyTensor::ViaDMA("", // edge name (non-existent)
+ nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
+ to_device, cpu_attr, to_alloc_attr, cpu_tensor,
+ to_tensor,
+ [this, cpu_tensor, done](const Status& s) {
+ delete cpu_tensor;
+ // This callback must not block, so execute
+ // done in another thread.
+ SchedClosure([s, done] { done(s); });
+ });
+ delete state;
+ return;
+ } else {
+ // CPU device
+ memcpy(DMAHelper::base(to_tensor), extra.tensor_content().data(),
+ num_bytes);
+ }
+ }
+ if (!s.ok() && errors::IsFailedPrecondition(s)) {
+ dev_resolver_->ClearTask(peer_task);
+ }
+
+ delete state;
+ done(s);
+ };
+
+ // Logic to execute once we have the device locality for the server-side
+ // device.
+ auto dev_locality_callback = [this, state, peer_device, peer_task, key,
+ to_device, to_device_ctx, to_alloc_attr,
+ to_tensor, client_locality,
+ recv_buf_callback](const Status& s) {
+ if (!s.ok()) {
+ recv_buf_callback(s);
+ } else {
+ state->call.reset(new RecvBufCall(
+ step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
+ to_alloc_attr, to_tensor, client_locality, state->server_locality,
+ &cancel_mgr_, worker_cache_));
+ state->call->Start(recv_buf_callback);
+ }
+ };
+
+ dev_resolver_->GetLocalityAsync(
+ peer_device, peer_task, &state->server_locality, dev_locality_callback);
+}
+
+void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
+ CollectiveRemoteAccessLocal::StartAbort(s);
+ cancel_mgr_.StartCancel();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
new file mode 100644
index 0000000000..cfa9110f47
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+class WorkerCacheInterface;
+
+// Extend CollectiveRemoteAccessLocal with access to remote peers.
+class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
+ public:
+ CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr,
+ DeviceResolverInterface* dev_resolver,
+ WorkerCacheInterface* worker_cache,
+ int64 step_id)
+ : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
+ worker_cache_(worker_cache) {}
+
+ ~CollectiveRemoteAccessDistributed() override {}
+
+ void RecvFromPeer(const string& peer_device, const string& peer_task,
+ bool peer_is_local, const string& key, Device* to_device,
+ DeviceContext* to_device_ctx,
+ const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
+ const DeviceLocality& client_locality,
+ const StatusCallback& done) override;
+
+ void StartAbort(const Status& s) override;
+
+ protected:
+ WorkerCacheInterface* worker_cache_; // Not owned
+ CancellationManager cancel_mgr_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
new file mode 100644
index 0000000000..a552f81f58
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -0,0 +1,356 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+
+#include "google/protobuf/any.pb.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/test_utils.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/transport_options.pb.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+// The only interesting method on CollectiveRemoteAccessDistributed
+// that's not on CollectiveRemoteAccessLocal is RecvFromPeer which
+// issues a RecvBufAsync call against a WorkerInterface. That's all
+// that's tested here. Note that RecvFromPeer can do a
+// DeviceResolverInterface::GetDeviceLocalityAsync call in preparation
+// for the RecvBufAsync.
+
+namespace tensorflow {
+namespace {
+
+static Device* NewDevice(const string& type, const string& name) {
+ class FakeDevice : public Device {
+ public:
+ explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
+ Status Sync() override { return Status::OK(); }
+ Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
+ };
+ DeviceAttributes attr;
+ attr.set_name(name);
+ attr.set_device_type(type);
+ attr.mutable_locality()->set_numa_node(3); // a non-default value
+ return new FakeDevice(attr);
+}
+
+static int64 kStepId = 123;
+
+class FakeWorker : public TestWorkerInterface {
+ public:
+ FakeWorker(const string& name, DeviceMgr* dev_mgr,
+ DeviceResolverDistributed* dres)
+ : name_(name),
+ device_mgr_(dev_mgr),
+ device_resolver_(dres),
+ buf_rendezvous_(kStepId) {}
+
+ // Direct access to a BufRendezvous that holds whatever the remote
+ // worker is supposed to have.
+ BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
+
+ void GetStatusAsync(const GetStatusRequest* request,
+ GetStatusResponse* response,
+ StatusCallback done) override {
+ std::vector<DeviceAttributes> dev_attr;
+ device_mgr_->ListDeviceAttributes(&dev_attr);
+ for (const auto& da : dev_attr) {
+ *response->add_device_attributes() = da;
+ }
+ done(Status::OK());
+ }
+
+ void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) override {
+ opts->SetCancelCallback([this]() {
+ // Within this test the call is satisfied by a process-local
+      // BufRendezvous table. In a real application the BufRendezvous
+ // would be on the other side of a network hop, so call
+ // BufRendezvous::StartAbort() from a separate thread to be
+ // more consistent with that situation and avoid mutex deadlock.
+ SchedClosure([this]() {
+ Env::Default()->SleepForMicroseconds(100);
+ buf_rendezvous_.StartAbort(errors::Internal("Cancelled"));
+ });
+ });
+ buf_rendezvous_.ConsumeBuf(
+ request->buf_rendezvous_key(),
+ [this, opts, request, response, done](const Status& s,
+ BufRendezvous::Hook* h) {
+ if (s.ok()) {
+ opts->ClearCancelCallback();
+            // Since this is not really RDMA into pre-allocated memory, send the
+ // bytes in the response.
+ RecvBufRespExtra extra;
+ int64 num_bytes = h->prod_value->TotalBytes();
+ extra.set_tensor_content(string(
+ reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)),
+ num_bytes));
+ response->mutable_transport_options()->PackFrom(extra);
+ }
+ done(s);
+ if (h) BufRendezvous::DoneWithHook(h);
+ });
+ }
+
+ private:
+ string name_;
+ DeviceMgr* device_mgr_;
+ DeviceResolverDistributed* device_resolver_;
+ BufRendezvous buf_rendezvous_;
+};
+
+class FakeCache : public TestWorkerCache {
+ public:
+ // Override the Locality methods to actually pass through to the
+ // worker.
+ bool GetDeviceLocalityNonBlocking(const string& device,
+ DeviceLocality* locality) override {
+ return false;
+ }
+
+ void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
+ StatusCallback done) override {
+ string task_name;
+ string dev_part;
+ if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
+ done(errors::Internal("failed to parse device name"));
+ return;
+ }
+ auto it = workers_.find(task_name);
+ if (it == workers_.end()) {
+ done(errors::Internal("failed to find worker ", task_name));
+ return;
+ }
+ WorkerInterface* wi = it->second;
+ GetStatusRequest req;
+ GetStatusResponse resp;
+ Notification note;
+ Status status = wi->GetStatus(&req, &resp);
+ if (!status.ok()) {
+ done(status);
+ return;
+ }
+ for (const auto& it : resp.device_attributes()) {
+ if (it.name() == device) {
+ *locality = it.locality();
+ done(Status::OK());
+ return;
+ }
+ }
+ done(errors::Internal("device not found: ", device));
+ }
+};
+
+class CollRMADistTest : public ::testing::Test {
+ protected:
+ CollRMADistTest() {}
+
+ ~CollRMADistTest() override {
+ for (DeviceMgr* dm : device_mgrs_) {
+ delete dm;
+ }
+ for (auto it : dev_resolvers_) {
+ delete it.second;
+ }
+ for (FakeWorker* w : workers_) {
+ delete w;
+ }
+ }
+
+ void SetUp() override {
+ const int num_workers = 2;
+ const int num_devices = 1;
+ string device_type = "CPU";
+ ConfigProto config;
+ string dev0_worker_name;
+ for (int w = 0; w < num_workers; ++w) {
+ string name = strings::StrCat("/job:worker/replica:0/task:", w);
+ if (w == 0) {
+ dev0_worker_name = name;
+ // TODO(tucker): Change to use config when available.
+ // config.set_collective_group_leader(name);
+ }
+ DefineWorker(config, name, device_type, num_devices);
+ }
+ // All tests simulate requests from worker 0 to worker 1.
+ rma_.reset(new CollectiveRemoteAccessDistributed(
+ device_mgrs_[0], dev_resolvers_[dev0_worker_name], &wc_, kStepId));
+
+ const int kNumElts = 8;
+ expected_value_ = Tensor(DT_FLOAT, {kNumElts});
+ to_tensor_ = Tensor(DT_FLOAT, {kNumElts});
+ auto exp_alias = expected_value_.flat<float>();
+ auto to_alias = to_tensor_.flat<float>();
+ for (int i = 0; i < kNumElts; ++i) {
+ exp_alias(i) = i;
+ to_alias(i) = -1;
+ }
+ }
+
+ void DefineWorker(const ConfigProto& config, const string& worker_name,
+ const string& device_type, int num_devices) {
+ std::vector<Device*> devices;
+ for (int i = 0; i < num_devices; ++i) {
+ devices.push_back(NewDevice(
+ device_type,
+ strings::StrCat(worker_name, "/device:", device_type, ":", i)));
+ }
+ DeviceMgr* dev_mgr = new DeviceMgr(devices);
+ device_mgrs_.push_back(dev_mgr);
+ std::vector<string>* dv = &dev_by_task_[worker_name];
+ for (auto d : devices) {
+ dv->push_back(d->name());
+ }
+ DeviceResolverDistributed* dev_res =
+ new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
+ dev_resolvers_[worker_name] = dev_res;
+ FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
+ workers_.push_back(fw);
+ wc_.AddWorker(worker_name, fw);
+ }
+
+ void ValidateResultTensor() {
+ ASSERT_EQ(expected_value_.NumElements(), to_tensor_.NumElements());
+ for (int i = 0; i < to_tensor_.NumElements(); ++i) {
+ EXPECT_FLOAT_EQ(expected_value_.flat<float>()(i),
+ to_tensor_.flat<float>()(i));
+ }
+ }
+
+ FakeCache wc_;
+ CancellationManager cm_;
+ std::vector<DeviceMgr*> device_mgrs_;
+ std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
+ std::unordered_map<string, std::vector<string>> dev_by_task_;
+ std::vector<FakeWorker*> workers_;
+ std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
+ mutex mu_;
+ int num_done_ GUARDED_BY(mu_);
+ condition_variable done_;
+ Tensor expected_value_;
+ Tensor to_tensor_;
+ CallOptions opts_;
+ DeviceLocality device_locality_;
+ AllocatorAttributes alloc_attr_;
+};
+
+TEST_F(CollRMADistTest, ProdFirstOK) {
+ Notification consumer_note;
+ Notification producer_note;
+ Status consumer_status;
+ Status producer_status;
+ FakeWorker* wi = workers_[1];
+ const string kBufKey = "fake_buf_key";
+ wi->buf_rendezvous()->ProvideBuf(
+ kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
+ AllocatorAttributes(),
+ [this, &producer_note, &producer_status](const Status& s) {
+ producer_status.Update(s);
+ producer_note.Notify();
+ });
+ Status status;
+ Device* dst_device = nullptr;
+ string dev_name = "CPU:0";
+ TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+ DeviceContext* to_device_ctx = nullptr;
+ rma_->RecvFromPeer(
+ "/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
+ "/job:worker/replica:0/task:1", // peer_task
+ false, // peer_is_local
+ kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+ device_locality_,
+ [this, &consumer_status, &consumer_note](const Status& s) {
+ consumer_status = s;
+ consumer_note.Notify();
+ });
+ consumer_note.WaitForNotification();
+ TF_EXPECT_OK(consumer_status);
+ producer_note.WaitForNotification();
+ TF_EXPECT_OK(producer_status);
+ ValidateResultTensor();
+}
+
+TEST_F(CollRMADistTest, ConsFirstOK) {
+ Notification consumer_note;
+ Notification producer_note;
+ Status consumer_status;
+ Status producer_status;
+ FakeWorker* wi = workers_[1];
+ const string kBufKey = "fake_buf_key";
+ Status status;
+ Device* dst_device = nullptr;
+ string dev_name = "CPU:0";
+ TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+ DeviceContext* to_device_ctx = nullptr;
+ rma_->RecvFromPeer(
+ "/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
+ "/job:worker/replica:0/task:1", // peer_task
+ false, // peer_is_local
+ kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+ device_locality_,
+ [this, &consumer_status, &consumer_note](const Status& s) {
+ consumer_status = s;
+ consumer_note.Notify();
+ });
+ wi->buf_rendezvous()->ProvideBuf(
+ kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
+ AllocatorAttributes(),
+ [this, &producer_note, &producer_status](const Status& s) {
+ producer_status.Update(s);
+ producer_note.Notify();
+ });
+ consumer_note.WaitForNotification();
+ TF_EXPECT_OK(consumer_status);
+ producer_note.WaitForNotification();
+ TF_EXPECT_OK(producer_status);
+ ValidateResultTensor();
+}
+
+TEST_F(CollRMADistTest, ConsFirstAbort) {
+ Notification consumer_note;
+ Status consumer_status;
+ const string kBufKey = "fake_buf_key";
+ Status status;
+ Device* dst_device = nullptr;
+ string dev_name = "CPU:0";
+ TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+ DeviceContext* to_device_ctx = nullptr;
+ rma_->RecvFromPeer(
+ "/job:worker/replica:0/task:1/device:" + dev_name, // peer_dev
+ "/job:worker/replica:0/task:1", // peer_task
+ false, // peer_is_local
+ kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+ device_locality_,
+ [this, &consumer_status, &consumer_note](const Status& s) {
+ consumer_status = s;
+ consumer_note.Notify();
+ });
+ rma_->StartAbort(errors::Internal("Deliberate Failure"));
+ consumer_note.WaitForNotification();
+ EXPECT_EQ(consumer_status.error_message(), "Cancelled");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index c2719f5462..40028ee241 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -171,6 +171,7 @@ tf_cuda_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core/distributed_runtime:graph_mgr",
"//tensorflow/core/distributed_runtime:recent_request_ids",
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
index 5b7b74ce63..1acf1fb4fc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
@@ -54,6 +54,7 @@ class GrpcRemoteWorker : public WorkerInterface {
cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
+ recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
logging_(Method(GrpcWorkerMethod::kLogging)),
tracing_(Method(GrpcWorkerMethod::kTracing)),
completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
@@ -118,6 +119,11 @@ class GrpcRemoteWorker : public WorkerInterface {
IssueRequest(request, response, cleanupall_, std::move(done));
}
+ void RecvBufAsync(CallOptions* call_opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) override {
+ IssueRequest(request, response, recvbuf_, std::move(done), call_opts);
+ }
+
void CompleteGroupAsync(CallOptions* call_opts,
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
@@ -239,6 +245,7 @@ class GrpcRemoteWorker : public WorkerInterface {
const ::grpc::string cleanupgraph_;
const ::grpc::string cleanupall_;
const ::grpc::string recvtensor_;
+ const ::grpc::string recvbuf_;
const ::grpc::string logging_;
const ::grpc::string tracing_;
const ::grpc::string completegroup_;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index 26fad1fc3c..137eb4a635 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "grpc++/alarm.h"
#include "grpc++/server_builder.h"
+#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -37,10 +38,12 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
namespace tensorflow {
@@ -159,6 +162,9 @@ class GrpcWorkerService : public AsyncServiceInterface {
for (int i = 0; i < 1000; ++i) {
EnqueueRecvTensorRequestRaw();
}
+ for (int i = 0; i < 500; ++i) {
+ ENQUEUE_REQUEST(RecvBuf, true);
+ }
for (int i = 0; i < 100; ++i) {
ENQUEUE_REQUEST(RunGraph, true);
}
@@ -170,9 +176,9 @@ class GrpcWorkerService : public AsyncServiceInterface {
ENQUEUE_REQUEST(Tracing, false);
for (int i = 0; i < 10; ++i) {
- ENQUEUE_REQUEST(CompleteGroup, false);
- ENQUEUE_REQUEST(CompleteInstance, false);
- ENQUEUE_REQUEST(GetStepSequence, false);
+ ENQUEUE_REQUEST(CompleteGroup, true);
+ ENQUEUE_REQUEST(CompleteInstance, true);
+ ENQUEUE_REQUEST(GetStepSequence, true);
}
void* tag;
@@ -322,6 +328,20 @@ class GrpcWorkerService : public AsyncServiceInterface {
ENQUEUE_REQUEST(Tracing, false);
}
+ void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
+ Schedule([this, call]() {
+ CallOptions* call_opts = new CallOptions;
+ call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
+ worker_->RecvBufAsync(call_opts, &call->request, &call->response,
+ [call, call_opts](const Status& s) {
+ call->ClearCancelCallback();
+ delete call_opts;
+ call->SendResponse(ToGrpcStatus(s));
+ });
+ });
+ ENQUEUE_REQUEST(RecvBuf, true);
+ }
+
void CompleteGroupHandler(
WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
Schedule([this, call]() {
@@ -334,7 +354,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
call->SendResponse(ToGrpcStatus(s));
});
});
- ENQUEUE_REQUEST(CompleteGroup, false);
+ ENQUEUE_REQUEST(CompleteGroup, true);
}
void CompleteInstanceHandler(
@@ -360,7 +380,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
&call->request, &call->response,
[call](const Status& s) { call->SendResponse(ToGrpcStatus(s)); });
});
- ENQUEUE_REQUEST(GetStepSequence, false);
+ ENQUEUE_REQUEST(GetStepSequence, true);
}
#undef ENQUEUE_REQUEST
@@ -485,11 +505,79 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
});
}
+void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) {
+  // This is a generic, low-performance implementation appropriate for gRPC.
+ CollectiveExecutor::Handle ce_handle(
+ env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
+ CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
+ rma->buf_rendezvous()->ConsumeBuf(
+ request->buf_rendezvous_key(),
+ [this, opts, request, response, done](const Status& status,
+ BufRendezvous::Hook* hook) {
+ Status s = status;
+ if (s.ok()) {
+ if (!DMAHelper::CanUseDMA(hook->prod_value)) {
+ s = errors::Internal("Tensor value for key ",
+ request->buf_rendezvous_key(),
+ " is not of a type supported by RecvBuf");
+ }
+ }
+ if (s.ok()) {
+ // The RPC source tensor needs to be in CPU RAM. If not already
+          // there, make a copy using memory appropriate to the purpose.
+ const size_t num_bytes = hook->prod_value->TotalBytes();
+ const bool on_host =
+ hook->prod_dev->attributes().device_type() == "CPU" ||
+ hook->prod_attr.on_host();
+ if ((!on_host) && (num_bytes > 0)) {
+ Device* cpu_dev = nullptr;
+ s = env_->device_mgr->LookupDevice("CPU:0", &cpu_dev);
+ if (s.ok()) {
+ AllocatorAttributes cpu_attr;
+ cpu_attr.set_gpu_compatible(true);
+ cpu_attr.set_nic_compatible(true);
+ Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
+ hook->prod_value->dtype(),
+ hook->prod_value->shape());
+ hook->prod_ctx->CopyDeviceTensorToCPU(
+ hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
+ [this, num_bytes, response, done, hook,
+ cpu_tensor](const Status& s) {
+ if (s.ok()) {
+ RecvBufRespExtra extra;
+ extra.set_tensor_content(reinterpret_cast<const char*>(
+ DMAHelper::base(cpu_tensor)),
+ num_bytes);
+ response->mutable_transport_options()->PackFrom(extra);
+ }
+ response->set_send_start_micros(env_->env->NowMicros());
+ done(s);
+ BufRendezvous::DoneWithHook(hook);
+ delete cpu_tensor;
+ });
+ return;
+ }
+ } else {
+ // Tensor is on CPU.
+ RecvBufRespExtra extra;
+ extra.set_tensor_content(reinterpret_cast<const char*>(
+ DMAHelper::base(hook->prod_value)),
+ num_bytes);
+ response->mutable_transport_options()->PackFrom(extra);
+ }
+ }
+ response->set_send_start_micros(env_->env->NowMicros());
+ done(s);
+ BufRendezvous::DoneWithHook(hook);
+ });
+}
+
void GrpcWorker::LoggingAsync(const LoggingRequest* request,
LoggingResponse* response, StatusCallback done) {
auto env = this->env();
if (env) {
- auto session_mgr = (SessionMgr*)env->session_mgr;
+ auto session_mgr = env->session_mgr;
if (session_mgr) {
session_mgr->SetLogging(request->rpc_logging());
for (const auto& step_id : request->fetch_step_id()) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
index fbddbda9e6..c0ed0884bc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
@@ -43,6 +43,9 @@ class GrpcWorker : public Worker {
virtual void LoggingAsync(const LoggingRequest* request,
LoggingResponse* response, StatusCallback done);
+ virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done);
+
WorkerEnv* env();
private:
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
index a91cc0692a..38cc2b81d3 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
@@ -46,6 +46,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
return "/tensorflow.WorkerService/CleanupAll";
case GrpcWorkerMethod::kRecvTensor:
return "/tensorflow.WorkerService/RecvTensor";
+ case GrpcWorkerMethod::kRecvBuf:
+ return "/tensorflow.WorkerService/RecvBuf";
case GrpcWorkerMethod::kLogging:
return "/tensorflow.WorkerService/Logging";
case GrpcWorkerMethod::kTracing:
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
index c5104c6a50..da270835bd 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
@@ -81,6 +81,7 @@ enum class GrpcWorkerMethod {
kCleanupGraph,
kCleanupAll,
kRecvTensor,
+ kRecvBuf,
kLogging,
kTracing,
kCompleteGroup,
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index 7ef4206c78..95b31c6991 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -67,7 +67,7 @@ Status SessionMgr::CreateSession(const string& session,
worker_name = WorkerNameFromServerDef(server_def);
}
- if (worker_cache != nullptr & default_worker_cache_.get() != nullptr) {
+ if (worker_cache != nullptr && default_worker_cache_ != nullptr) {
worker_cache->SetLogging(this->is_logging_active_);
}
diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h
index 0ed078241f..48d83845dd 100644
--- a/tensorflow/core/distributed_runtime/test_utils.h
+++ b/tensorflow/core/distributed_runtime/test_utils.h
@@ -93,6 +93,11 @@ class TestWorkerInterface : public WorkerInterface {
done(errors::Unimplemented("RunGraphAsync"));
}
+ void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) override {
+ done(errors::Unimplemented("RecvBufAsync"));
+ }
+
void CompleteGroupAsync(CallOptions* opts,
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index d682ac8f34..4e6500fbc6 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -337,6 +337,15 @@ void Worker::TracingAsync(const TracingRequest* request,
done(errors::Unimplemented("Tracing"));
}
+void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) {
+ // The base Worker class does not implement RecvBufAsync because
+ // it is not currently used for worker-to-worker communication. Use a
+ // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
+ // instead.
+ done(errors::Unimplemented("Worker::RecvBufAsync()"));
+}
+
void Worker::CompleteGroupAsync(CallOptions* opts,
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h
index b5a9ada502..91eb27c10e 100644
--- a/tensorflow/core/distributed_runtime/worker.h
+++ b/tensorflow/core/distributed_runtime/worker.h
@@ -90,6 +90,9 @@ class Worker : public WorkerInterface {
void TracingAsync(const TracingRequest* request, TracingResponse* response,
StatusCallback done) override;
+ void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) override;
+
void CompleteGroupAsync(CallOptions* opts,
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
diff --git a/tensorflow/core/distributed_runtime/worker_cache_partial.cc b/tensorflow/core/distributed_runtime/worker_cache_partial.cc
index 61e5416234..55b6957b96 100644
--- a/tensorflow/core/distributed_runtime/worker_cache_partial.cc
+++ b/tensorflow/core/distributed_runtime/worker_cache_partial.cc
@@ -67,7 +67,7 @@ Status WorkerCachePartial::RefreshDeviceStatus(const string& device_name) {
};
std::unique_ptr<WorkerInterface, decltype(deleter)> rwi(CreateWorker(task),
deleter);
- if (s.ok() && !rwi.get()) {
+ if (s.ok() && !rwi) {
s = errors::Internal("RefreshDeviceStatus, unknown worker task: ", task);
}
diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h
index bad31d27b2..a50ac3b8ae 100644
--- a/tensorflow/core/distributed_runtime/worker_interface.h
+++ b/tensorflow/core/distributed_runtime/worker_interface.h
@@ -112,6 +112,9 @@ class WorkerInterface {
virtual void TracingAsync(const TracingRequest* request,
TracingResponse* response, StatusCallback done) = 0;
+ virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+ RecvBufResponse* response, StatusCallback done) = 0;
+
virtual void CompleteGroupAsync(CallOptions* opts,
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
index 87c1ddd15d..79966f0692 100644
--- a/tensorflow/core/framework/attr_value_util.cc
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -33,6 +33,154 @@ limitations under the License.
namespace tensorflow {
namespace {
+// Do not construct large tensors to compute their hash or compare for equality.
+constexpr int kMaxAttrValueTensorByteSize = 32 * 1024 * 1024; // 32mb
+
+// Returns the size in bytes of the tensor represented by this TensorProto. If
+// the shape is not fully defined, returns -1.
+int64 TensorByteSize(const TensorProto& t) {
+ // num_elements returns -1 if shape is not fully defined.
+ int64 num_elems = TensorShape(t.tensor_shape()).num_elements();
+ return num_elems < 0 ? -1 : num_elems * DataTypeSize(t.dtype());
+}
+
+// Computes the TensorProto hash by creating a Tensor, serializing it as tensor
+// content, and computing a hash of its string representation. This is an
+// unsafe operation, because large tensors can be represented as a TensorProto
+// but cannot be serialized to tensor content.
+uint64 TensorProtoHash(const TensorProto& tp) {
+ Tensor tensor(tp.dtype());
+ bool success = tensor.FromProto(tp);
+ DCHECK(success);
+ TensorProto p;
+ tensor.AsProtoTensorContent(&p);
+ string s;
+ SerializeToStringDeterministic(p, &s);
+ return Hash64(s);
+}
+
+// Does not create large tensors in memory; computes the hash based on the
+// TensorProto string representation. Tensors with identical content can end up
+// with different hash codes if they are defined with different TensorProto
+// representations.
+uint64 FastTensorProtoHash(const TensorProto& tp) {
+  if (TensorByteSize(tp) > kMaxAttrValueTensorByteSize) {
+    string s;
+ bool success = SerializeToStringDeterministic(tp, &s);
+ DCHECK(success);
+ return Hash64(s);
+ } else {
+ return TensorProtoHash(tp);
+ }
+}
+
+// There are multiple equivalent representations of attr values containing
+// TensorProtos. Compare them by constructing Tensors and serializing them
+// back. Comparing Tensor objects directly is pretty tricky. This is an unsafe
+// operation, because large tensors can be represented as a TensorProto but
+// cannot be serialized to tensor content.
+bool AreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
+ Tensor lhs_t(lhs.dtype());
+ bool success = lhs_t.FromProto(lhs);
+ DCHECK(success);
+
+ Tensor rhs_t(rhs.dtype());
+ success = rhs_t.FromProto(rhs);
+ DCHECK(success);
+
+ TensorProto lhs_tp;
+ lhs_t.AsProtoTensorContent(&lhs_tp);
+
+ TensorProto rhs_tp;
+ rhs_t.AsProtoTensorContent(&rhs_tp);
+
+ string lhs_str, rhs_str;
+ SerializeToStringDeterministic(lhs_tp, &lhs_str);
+ SerializeToStringDeterministic(rhs_tp, &rhs_str);
+
+ return lhs_str == rhs_str;
+}
+
+// Does not construct large tensors in memory; compares equality using the
+// TensorProto string representation. Tensors with identical content can have
+// different TensorProto representations, in which case this returns false.
+bool FastAreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
+ if (TensorByteSize(lhs) > kMaxAttrValueTensorByteSize ||
+ TensorByteSize(rhs) > kMaxAttrValueTensorByteSize) {
+ string lhs_str, rhs_str;
+ bool success = lhs.AppendToString(&lhs_str);
+ DCHECK(success);
+ success = rhs.AppendToString(&rhs_str);
+ DCHECK(success);
+
+ return lhs_str == rhs_str;
+ } else {
+ return AreTensorProtosEqual(lhs, rhs);
+ }
+}
+
+using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
+using TensorProtosEquality =
+ std::function<bool(const TensorProto&, const TensorProto&)>;
+
+uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
+ if (a.has_tensor()) return tensor_hash(a.tensor());
+
+ if (a.has_func()) {
+ const NameAttrList& func = a.func();
+ uint64 h = Hash64(func.name());
+ std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
+ for (const auto& pair : map) {
+ h = Hash64(pair.first.data(), pair.first.size(), h);
+ h = Hash64Combine(AttrValueHash(pair.second, tensor_hash), h);
+ }
+ return h;
+ }
+
+ // If `a` is not a tensor or func, get a hash of serialized string.
+ string s;
+ SerializeToStringDeterministic(a, &s);
+ return Hash64(s);
+}
+
+bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
+ const TensorProtosEquality& tensor_equality) {
+ if (a.has_tensor() != b.has_tensor()) {
+ return false;
+ } else if (a.has_tensor() && b.has_tensor()) {
+ return tensor_equality(a.tensor(), b.tensor());
+ }
+
+ // `func` field contains a nested AttrValue. Compare such AttrValues
+ // recursively.
+ if (a.has_func() != b.has_func()) {
+ return false;
+ } else if (a.has_func() && b.has_func()) {
+ const NameAttrList& af = a.func();
+ const NameAttrList& bf = b.func();
+ if (af.name() != bf.name()) return false;
+ std::unordered_map<string, AttrValue> am(af.attr().begin(),
+ af.attr().end());
+ for (const auto& bm_pair : bf.attr()) {
+ const auto& iter = am.find(bm_pair.first);
+ if (iter == am.end()) return false;
+ if (!AreAttrValuesEqual(iter->second, bm_pair.second, tensor_equality))
+ return false;
+ am.erase(iter);
+ }
+ if (!am.empty()) return false;
+ return true;
+ }
+
+ // All other fields in AttrValue have deterministic representations.
+ // It is safe to compare their serialized strings.
+ string a_str, b_str;
+ SerializeToStringDeterministic(a, &a_str);
+ SerializeToStringDeterministic(b, &b_str);
+ return a_str == b_str;
+}
+
string SummarizeString(const string& str) {
string escaped = str_util::CEscape(str);
@@ -412,89 +560,19 @@ void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
}
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
- // There are multiple equivalent representations of attr values containing
- // TensorProtos. Compare them by constructing Tensors and serializing them
- // back. Comparing Tensor objects is pretty tricky.
- if (a.has_tensor() != b.has_tensor()) {
- return false;
- } else if (a.has_tensor() && b.has_tensor()) {
- Tensor at(a.tensor().dtype());
- bool success = at.FromProto(a.tensor());
- DCHECK(success);
-
- Tensor bt(b.tensor().dtype());
- success = bt.FromProto(b.tensor());
- DCHECK(success);
-
- TensorProto ap;
- at.AsProtoTensorContent(&ap);
-
- TensorProto bp;
- bt.AsProtoTensorContent(&bp);
-
- string a_str, b_str;
- SerializeToStringDeterministic(ap, &a_str);
- SerializeToStringDeterministic(bp, &b_str);
- return a_str == b_str;
- }
-
- // `func` field contains a nested AttrValue. Compare such AttrValues
- // recursively.
- if (a.has_func() != b.has_func()) {
- return false;
- } else if (a.has_func() && b.has_func()) {
- const NameAttrList& af = a.func();
- const NameAttrList& bf = b.func();
- if (af.name() != bf.name()) return false;
- std::unordered_map<string, AttrValue> am(af.attr().begin(),
- af.attr().end());
- for (const auto& bm_pair : bf.attr()) {
- const auto& iter = am.find(bm_pair.first);
- if (iter == am.end()) return false;
- if (!AreAttrValuesEqual(iter->second, bm_pair.second)) return false;
- am.erase(iter);
- }
- if (!am.empty()) return false;
- return true;
- }
-
- // All other fields in AttrValue have deterministic representations.
- // It is safe to compare their serialized strings.
- string a_str, b_str;
- SerializeToStringDeterministic(a, &a_str);
- SerializeToStringDeterministic(b, &b_str);
- return a_str == b_str;
+ return AreAttrValuesEqual(a, b, AreTensorProtosEqual);
}
uint64 AttrValueHash(const AttrValue& a) {
- if (a.has_tensor()) {
- // Deal with multiple representations by parsing TensorProto to
- // Tensor and serializing it back. This is slow, but current use case
- // don't need high efficiency.
- Tensor tensor(a.tensor().dtype());
- bool success = tensor.FromProto(a.tensor());
- DCHECK(success);
- TensorProto p;
- tensor.AsProtoTensorContent(&p);
- string s;
- SerializeToStringDeterministic(p, &s);
- return Hash64(s);
- }
- if (a.has_func()) {
- const NameAttrList& func = a.func();
- uint64 h = Hash64(func.name());
- std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
- for (const auto& pair : map) {
- h = Hash64(pair.first.data(), pair.first.size(), h);
- h = Hash64Combine(AttrValueHash(pair.second), h);
- }
- return h;
- }
+ return AttrValueHash(a, TensorProtoHash);
+}
- // If `a` is not a tensor or func, get a hash of serialized string.
- string s;
- SerializeToStringDeterministic(a, &s);
- return Hash64(s);
+bool FastAreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
+ return AreAttrValuesEqual(a, b, FastAreTensorProtosEqual);
+}
+
+uint64 FastAttrValueHash(const AttrValue& a) {
+ return AttrValueHash(a, FastTensorProtoHash);
}
bool HasPlaceHolder(const AttrValue& val) {
diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h
index 29e34c5090..0da9b1081b 100644
--- a/tensorflow/core/framework/attr_value_util.h
+++ b/tensorflow/core/framework/attr_value_util.h
@@ -98,6 +98,19 @@ bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b);
// probably not persist the returned value.
uint64 AttrValueHash(const AttrValue& a);
+// WARNING: The equality check may return a false negative for large (> 32mb)
+// tensors that are defined with different TensorProto representations.
+//
+// A pair of consistent hash and equality functions that are guaranteed to be
+// fast for AttrValues that can contain very large Tensors (larger than 32mb)
+// defined by TensorProto. If large identical Tensors are defined using
+// different representations (e.g. one with tensor content and the other with
+// bool_val), they will have different hash codes and equality will return
+// false. Small (less than 32mb) tensors with different TensorProto
+// representations are hashed and compared by their tensor content.
+uint64 FastAttrValueHash(const AttrValue& a);
+bool FastAreAttrValuesEqual(const AttrValue& a, const AttrValue& b);
+
// Returns true if "val" has a placeholder.
bool HasPlaceHolder(const AttrValue& val);
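
A minimal sketch of the contract the new fast pair is meant to provide (the values below are illustrative, not part of the patch): equal AttrValues compare equal and hash identically, and only oversized tensors encoded differently can produce the false negatives warned about above.

    AttrValue a, b;
    a.set_i(42);
    b.set_i(42);
    CHECK(FastAreAttrValuesEqual(a, b));
    CHECK_EQ(FastAttrValueHash(a), FastAttrValueHash(b));
    // Only tensors above kMaxAttrValueTensorByteSize fall back to comparing the
    // raw TensorProto encoding, so two different encodings of the same large
    // tensor may hash/compare as different -- the "false negative" noted above.
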
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 0916c9b7a8..71a31b0e75 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1417,6 +1417,21 @@ Status ExplicitShape(InferenceContext* c) {
return Status::OK();
}
+Status ExplicitShapes(InferenceContext* c) {
+ std::vector<PartialTensorShape> shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
+ if (shapes.empty()) {
+ return errors::Internal("shapes attribute is empty");
+ }
+ for (int i = 0; i < shapes.size(); ++i) {
+ ShapeHandle output_shape;
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape));
+ c->set_output(i, output_shape);
+ }
+ return Status::OK();
+}
+
} // namespace shape_inference
} // namespace tensorflow
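
A hedged sketch of how ExplicitShapes is meant to be wired up: an op whose output shapes are all listed in a "shapes" attribute. The op name and attrs below are illustrative only; the number of outputs must match the number of entries in "shapes".

    REGISTER_OP("ExampleMultiOutput")
        .Output("outputs: N * float")
        .Attr("N: int >= 1")
        .Attr("shapes: list(shape)")
        .SetShapeFn(shape_inference::ExplicitShapes);
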
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 789746b403..87bb133d92 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -289,6 +289,9 @@ Status ScatterNdUpdateShape(InferenceContext* c);
// Shape function for ops with an explicit "shape" attribute.
Status ExplicitShape(InferenceContext* c);
+// Shape function for multiple-output ops with an explicit "shapes" attribute.
+Status ExplicitShapes(InferenceContext* c);
+
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 3185875e3b..b02bc3adbe 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -616,8 +616,9 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
int64 end_in = end;
const int32 rank = Rank(s);
- if (start == 0 && ((RankKnown(s) && end >= rank) ||
- end == std::numeric_limits<int64>::max())) {
+ if (start == 0 && stride == 1 &&
+ ((RankKnown(s) && end >= rank) ||
+ end == std::numeric_limits<int64>::max())) {
*out = s;
return Status::OK();
}
@@ -663,7 +664,6 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
}
std::vector<DimensionHandle> dims;
- dims.reserve((end - start) / stride);
for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
dims.push_back(Dim(s, i));
}
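
A worked example of why the fast path now also requires stride == 1 (the shape values are illustrative): only a unit-stride full slice may return the input handle unchanged.

    // Given s = [2, 3, 5, 7]:
    //   Subshape(s, /*start=*/0, /*end=*/kint64max, /*stride=*/1)  -> [2, 3, 5, 7]
    //     (fast path: the input handle is returned as-is)
    //   Subshape(s, /*start=*/0, /*end=*/kint64max, /*stride=*/2)  -> [2, 5]
    //     (must fall through to the element-by-element loop; before this change
    //      the start == 0 fast path would have returned the full shape)
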
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 5e2a465e22..029cdcf94a 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -2022,6 +2022,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
@@ -2051,6 +2052,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);DMT/_0(Const);DMT/_1(Const)|"
@@ -2069,6 +2071,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
@@ -2095,6 +2098,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Input'}"
@@ -2125,6 +2129,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
@@ -2151,6 +2156,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Positive) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C'] }"
"node { name: 'E' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
@@ -2178,6 +2184,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative1) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C'] }"
"node { name: 'E' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
@@ -2204,6 +2211,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative2) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C'] }"
"node { name: 'E' op: 'BiasAddGrad'"
" attr { key: 'T' value { type: DT_FLOAT } }"
@@ -2233,6 +2241,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackpropFilterFusion_Negative3) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
"node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
@@ -2272,6 +2281,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_ConvBpropInput_FilterFwd) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'BiasAdd'"
@@ -2289,6 +2299,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_ConvBpropInput_FilterFwd) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['F', 'B', 'E']}"
"node { name: 'Z' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
@@ -2319,6 +2330,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
@@ -2341,6 +2353,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
@@ -2348,6 +2361,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
@@ -2370,6 +2384,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
" input: ['B', 'C'] }");
@@ -2389,6 +2404,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
@@ -2411,6 +2427,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['B', 'A', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
@@ -2477,6 +2494,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_BiasAddGrad_Positive2) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
@@ -2529,6 +2547,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'F' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
@@ -2536,6 +2555,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D']}"
"node { name: 'G' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
@@ -2572,6 +2592,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}"
@@ -2634,6 +2655,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'F' op: 'Conv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
@@ -2641,6 +2663,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['C', 'D']}"
"node { name: 'G' op: 'Const' "
" attr { key: 'dtype' value { type: DT_INT32 } }"
@@ -2678,6 +2701,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}"
@@ -3274,6 +3298,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B']}"
"node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }",
@@ -3296,6 +3321,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
"node { name: 'E' op: 'Zeta'"
" attr {key: 'T' value { type: DT_FLOAT } }"
@@ -3323,6 +3349,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
+ " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }",
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index d8c4489a57..69b7594735 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -816,6 +816,60 @@ class SymbolicShapeRefiner {
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = result;
}
+ } else if (IsStridedSlice(node)) {
+ ShapeHandle input = ic->input_tensors_as_shapes()[0];
+ bool valid = ic->RankKnown(input);
+ const Tensor* slice_begin = ic->input_tensor(1);
+ valid &= slice_begin != nullptr && slice_begin->NumElements() == 1;
+ const Tensor* slice_end = ic->input_tensor(2);
+ valid &= slice_end != nullptr && slice_end->NumElements() == 1;
+ const Tensor* slice_stride = ic->input_tensor(3);
+ valid &= slice_stride != nullptr && slice_stride->NumElements() == 1;
+
+ if (node.attr().count("ellipsis_mask") > 0 &&
+ node.attr().at("ellipsis_mask").i() != 0) {
+ valid = false;
+ }
+ if (node.attr().count("new_axis_mask") > 0 &&
+ node.attr().at("new_axis_mask").i() != 0) {
+ valid = false;
+ }
+ if (node.attr().count("shrink_axis_mask") > 0 &&
+ node.attr().at("shrink_axis_mask").i() != 0) {
+ valid = false;
+ }
+ int begin_mask = 0;
+ if (node.attr().count("begin_mask") > 0) {
+ begin_mask = node.attr().at("begin_mask").i();
+ }
+ int end_mask = 0;
+ if (node.attr().count("end_mask") > 0) {
+ end_mask = node.attr().at("end_mask").i();
+ }
+ if (begin_mask < 0 || begin_mask > 1 || end_mask < 0 || end_mask > 1) {
+ valid = false;
+ }
+ if (valid) {
+ int64 begin = 0;
+ if (begin_mask == 0) {
+ begin = slice_begin->dtype() == DT_INT32
+ ? slice_begin->flat<int32>()(0)
+ : slice_begin->flat<int64>()(0);
+ }
+ int64 end = std::numeric_limits<int64>::max();
+ if (end_mask == 0) {
+ end =
+ (slice_end->dtype() == DT_INT32 ? slice_end->flat<int32>()(0)
+ : slice_end->flat<int64>()(0));
+ }
+ int64 stride = slice_stride->dtype() == DT_INT32
+ ? slice_stride->flat<int32>()(0)
+ : slice_stride->flat<int64>()(0);
+ ShapeHandle result;
+ TF_RETURN_IF_ERROR(ic->Subshape(input, begin, end, stride, &result));
+ c->output_tensors_as_shapes.resize(1);
+ c->output_tensors_as_shapes[0] = result;
+ }
}
}
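
A short walk-through of the new branch, under the same assumptions as the test added further below (node names are illustrative):

    // a   : Placeholder with shape [-1, -1]   -> symbolic dims {d0, d1}
    // shp : Shape(a)                          -> value known as the shape [d0, d1]
    // b   : StridedSlice(shp, begin=[0], end=[1], strides=[1])
    //
    // With no ellipsis/new_axis/shrink_axis masks, the branch maps the slice
    // onto InferenceContext::Subshape(input, /*start=*/0, /*end=*/1, /*stride=*/1),
    // so b's value is inferred as the shape [d0], and a downstream Fill(b, 0.0f)
    // receives an output shape that shares the symbolic dimension d0 with `a`.
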
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index a53f6414c3..3e44b222fd 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -952,6 +952,39 @@ TEST_F(GraphPropertiesTest, Performance) {
TF_CHECK_OK(properties.InferStatically(false));
}
+TEST_F(GraphPropertiesTest, StridedSlicesOfShapes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a =
+ ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
+ ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+ auto shp = ops::Shape(s.WithOpName("shape"), {a});
+
+ Output index1 = ops::Const(s.WithOpName("index1"), 0, {1});
+ Output index2 = ops::Const(s.WithOpName("index2"), 1, {1});
+ Output index3 = ops::Const(s.WithOpName("index3"), 2, {1});
+
+ Output b = ops::StridedSlice(s.WithOpName("b"), shp, index1, index2, index2);
+ Output c = ops::StridedSlice(s.WithOpName("c"), shp, index2, index3, index2);
+
+ Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
+ Output o1 = ops::Fill(s.WithOpName("o1"), b, zero);
+ Output o2 = ops::Fill(s.WithOpName("o2"), c, zero);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
+ const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
+ const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
+ EXPECT_EQ(2, shape_a.dim_size());
+ EXPECT_EQ(1, shape_o1.dim_size());
+ EXPECT_EQ(1, shape_o2.dim_size());
+ EXPECT_EQ(shape_a.dim(0).size(), shape_o1.dim(0).size());
+ EXPECT_EQ(shape_a.dim(1).size(), shape_o2.dim(0).size());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 2542fa2d67..b8e337582c 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -129,33 +129,6 @@ int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
}
}
-// Return a minimum shape if the shape is unknown. If known, return the original
-// shape.
-TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
- int rank, bool* found_unknown_shapes) {
- auto shape = original_shape;
- if (shape.unknown_rank() || shape.dim_size() < rank) {
- *found_unknown_shapes = true;
- TensorShapeProto::Dim dim;
- VLOG(2) << "Use minimum shape because the rank is unknown.";
- // The size of each dimension is at least 1, if unknown.
- dim.set_size(1);
- for (int i = 0; i < rank; i++) {
- *shape.add_dim() = dim;
- }
- } else {
- for (int i = 0; i < shape.dim_size(); i++) {
- if (shape.dim(i).size() < 0) {
- *found_unknown_shapes = true;
- VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
- // The size of each dimension is at least 1, if unknown.
- shape.mutable_dim(i)->set_size(1);
- }
- }
- }
- return shape;
-}
-
// Return the output element count of a binary element-wise op considering
// broadcasting.
int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
@@ -187,6 +160,33 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
} // namespace
+// Return a minimum shape if the shape is unknown. If known, return the original
+// shape.
+TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
+ int rank, bool* found_unknown_shapes) {
+ auto shape = original_shape;
+ if (shape.unknown_rank() || shape.dim_size() < rank) {
+ *found_unknown_shapes = true;
+ TensorShapeProto::Dim dim;
+ VLOG(2) << "Use minimum shape because the rank is unknown.";
+ // The size of each dimension is at least 1, if unknown.
+ dim.set_size(1);
+ for (int i = 0; i < rank; i++) {
+ *shape.add_dim() = dim;
+ }
+ } else {
+ for (int i = 0; i < shape.dim_size(); i++) {
+ if (shape.dim(i).size() < 0) {
+ *found_unknown_shapes = true;
+ VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
+ // The size of each dimension is at least 1, if unknown.
+ shape.mutable_dim(i)->set_size(1);
+ }
+ }
+ }
+ return shape;
+}
+
OpLevelCostEstimator::OpLevelCostEstimator() {
// Syntactic sugar to build and return a lambda that takes an OpInfo and
// returns a cost.
@@ -865,6 +865,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
conv_dims.oz *= conv_dims.iz;
ops *= conv_dims.oz;
}
+ ops *= kOpsPerMac;
VLOG(1) << "Operations for" << op_features.op() << " " << ops;
@@ -921,7 +922,7 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
conv_dims.oz *= conv_dims.iz;
ops *= conv_dims.oz;
}
-
+ ops *= kOpsPerMac;
VLOG(1) << "Operations for" << op_features.op() << " " << ops;
if (returned_conv_dims != nullptr) {
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 35649f7ee9..d384f57279 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -30,6 +30,8 @@ namespace grappler {
bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
TensorShapeProto* tensor_shape_proto);
+TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
+ int rank, bool* found_unknown_shapes);
class OpLevelCostEstimator {
public:
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index e633ecf789..07f826beed 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -408,6 +408,21 @@ bool IsPersistent(const NodeDef& node) {
return IsConstant(node) || IsVariable(node);
}
+bool MaybeHasRefInput(const NodeDef& node) {
+ const OpDef* op_def;
+ Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
+ if (!status.ok()) {
+ return true;
+ }
+ // Nodes such as Assign or AssignAdd modify one of their inputs.
+ for (const auto& input : op_def->input_arg()) {
+ if (input.is_ref()) {
+ return true;
+ }
+ }
+ return false;
+}
+
bool IsFreeOfSideEffect(const NodeDef& node) {
// Placeholders must be preserved to keep the graph feedable.
if (IsPlaceholder(node)) {
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index f6105d710e..a5599eb22e 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -166,6 +166,10 @@ bool IsPersistent(const NodeDef& node);
bool IsFreeOfSideEffect(const NodeDef& node);
+// Returns true if the node takes a tensor reference as input, or if looking up
+// its OpDef failed.
+bool MaybeHasRefInput(const NodeDef& node);
+
bool ModifiesFrameInfo(const NodeDef& node);
// Returns true if the op is known to write to one or more of its inputs.
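
Illustrative expectations for the new predicate, assuming the standard op registry is loaded (the op names below are just examples):

    NodeDef assign;
    assign.set_op("Assign");       // first input is a ref  -> true
    NodeDef add;
    add.set_op("Add");             // no ref inputs         -> false
    NodeDef unknown;
    unknown.set_op("NotARealOp");  // OpDef lookup fails    -> true (conservative)
    // MaybeHasRefInput(assign) == true
    // MaybeHasRefInput(add) == false
    // MaybeHasRefInput(unknown) == true
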
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 900dfa95c5..e1c2a64da1 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -508,7 +508,6 @@ cc_library(
":arithmetic_optimizer",
":auto_parallel",
":constant_folding",
- ":custom_graph_optimizer",
":custom_graph_optimizer_registry",
":debug_stripper",
":dependency_optimizer",
@@ -518,6 +517,7 @@ cc_library(
":loop_optimizer",
":memory_optimizer",
":model_pruner",
+ ":shape_optimizer",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -630,12 +630,50 @@ tf_cuda_cc_test(
)
cc_library(
+ name = "shape_optimizer",
+ srcs = ["shape_optimizer.cc"],
+ hdrs = [
+ "shape_optimizer.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_optimizer",
+ ":symbolic_shapes",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:frame",
+ ],
+)
+
+tf_cc_test(
+ name = "shape_optimizer_test",
+ srcs = ["shape_optimizer_test.cc"],
+ deps = [
+ ":shape_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
+
+cc_library(
name = "symbolic_shapes",
srcs = ["symbolic_shapes.cc"],
hdrs = ["symbolic_shapes.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
] + tf_protos_grappler(),
)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index adfae2e1a3..adef75f63e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
@@ -38,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"
@@ -254,6 +256,17 @@ NodeDef* GetTailOfValuePreservingChain(
is_value_preserving_non_branching);
}
+NodeDef* GetTailOfIdempotentChain(
+ const NodeDef& node, const NodeMap& node_map,
+ const std::unordered_set<string>& nodes_to_preserve) {
+ auto is_idempotent_non_branching = [&](const NodeDef& node) {
+ return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
+ IsIdempotent(node) && NumNonControlOutputs(node, node_map) == 1;
+ };
+ return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
+ is_idempotent_non_branching);
+}
+
// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
@@ -270,7 +283,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
const ArithmeticOptimizerContext ctx_ext)
: GraphOptimizerStage("ArithmeticOptimizer", name, ctx),
ctx_ext_(ctx_ext) {}
- virtual ~ArithmeticOptimizerStage() = default;
+ ~ArithmeticOptimizerStage() override = default;
protected:
// Simplification graph rewrite can create additional nodes that are inputs
@@ -1149,21 +1162,27 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
public:
explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
- const ArithmeticOptimizerContext& ctx_ext)
- : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
+ const ArithmeticOptimizerContext& ctx_ext,
+ RewriterConfig::Toggle opt_level)
+ : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext),
+ opt_level_(opt_level) {}
~RemoveIdentityTranspose() override = default;
bool IsSupported(const NodeDef* node) const override {
return IsTranspose(*node) || IsConjugateTranspose(*node);
}
- // TODO(rmlarsen): Forward control dependencies on the bypassed
- // transpose nodes.
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
+ NodeDef* tail = node;
+ // TODO(rmlarsen): Enable in regular mode after May 15, 2018.
+ if (opt_level_ == RewriterConfig::AGGRESSIVE) {
+ tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+ *ctx().nodes_to_preserve);
+ }
+ NodeDef* first_transpose;
+ TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
- NodeDef* input;
- TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
NodeDef* node_perm;
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
if (!IsConstant(*node_perm)) {
@@ -1171,17 +1190,30 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
}
std::vector<int64> node_perm_values;
TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
- if (input->op() == node->op()) {
+ if (first_transpose->op() == node->op()) {
// Remove pairs of transposes that cancel each other.
- NodeDef* input_perm;
- TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm));
- if (!IsConstant(*input_perm)) {
+ NodeDef* first_transpose_perm;
+ TF_RETURN_IF_ERROR(
+ GetInputNode(first_transpose->input(1), &first_transpose_perm));
+ if (!IsConstant(*first_transpose_perm)) {
return Status::OK();
}
- std::vector<int64> input_perm_values;
- TF_RETURN_IF_ERROR(GetPermutation(*input_perm, &input_perm_values));
- if (AreInversePermutations(node_perm_values, input_perm_values)) {
- *simplified_node_name = input->input(0);
+ std::vector<int64> first_transpose_perm_values;
+ TF_RETURN_IF_ERROR(
+ GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
+ if (AreInversePermutations(node_perm_values,
+ first_transpose_perm_values)) {
+ if (tail == node) {
+ // Bypass adjacent pair.
+ *simplified_node_name = first_transpose->input(0);
+ } else {
+ // Bypass pair connected through chain.
+ tail->set_input(0, first_transpose->input(0));
+ ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
+ first_transpose->input(0));
+ ForwardControlDependencies(tail, {first_transpose});
+ *simplified_node_name = node->input(0);
+ }
}
} else {
// Remove simple identity transposes.
@@ -1231,6 +1263,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
}
return true;
}
+
+ RewriterConfig::Toggle opt_level_;
};
// Remove redundant Bitcasts.
@@ -1752,7 +1786,7 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
class UniqueNodes {
public:
NodeDef* FindOrAddRepresentative(NodeDef* node) {
- std::size_t sig = ComputeSignature(*node);
+ uint64 sig = ComputeSignature(*node);
std::vector<NodeDef*>& candidates = rep_[sig];
for (auto& candidate : candidates) {
if (SameNode(*candidate, *node)) {
@@ -1764,26 +1798,25 @@ class UniqueNodes {
}
private:
- std::size_t ComputeSignature(const NodeDef& node) const;
+ uint64 ComputeSignature(const NodeDef& node) const;
bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
- std::unordered_map<std::size_t, std::vector<NodeDef*>> rep_;
+ std::unordered_map<uint64, std::vector<NodeDef*>> rep_;
};
-std::size_t UniqueNodes::ComputeSignature(const NodeDef& node) const {
- std::size_t h = std::hash<string>{}(node.op());
- h ^= std::hash<string>{}(node.device());
+uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const {
+ uint64 h = Hash64(node.op());
+ h = Hash64Combine(Hash64(node.device()), h);
+
for (const auto& input : node.input()) {
int pos;
string node_name = ParseNodeName(input, &pos);
- h ^= std::hash<string>{}(node_name);
- h ^= static_cast<std::size_t>(pos);
+ h = Hash64CombineUnordered(Hash64(node_name), h);
+ h = Hash64CombineUnordered(std::hash<int>()(pos), h);
}
for (const auto& attr : node.attr()) {
- h ^= std::hash<string>{}(attr.first);
- string tmp;
- attr.second.AppendToString(&tmp);
- h ^= std::hash<string>{}(tmp);
+ h = Hash64CombineUnordered(Hash64(attr.first), h);
+ h = Hash64CombineUnordered(FastAttrValueHash(attr.second), h);
}
return h;
}
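
A minimal sketch (names and values illustrative) of the property the new signature relies on: a node's attrs live in a proto map with unspecified iteration order, and Hash64CombineUnordered keeps the result independent of the order in which their contributions are folded in.

    AttrValue t, transpose_a;
    t.set_type(DT_FLOAT);
    transpose_a.set_b(false);
    uint64 h = Hash64Combine(Hash64("/cpu:0"), Hash64("MatMul"));
    uint64 ha = Hash64CombineUnordered(
        FastAttrValueHash(transpose_a),
        Hash64CombineUnordered(FastAttrValueHash(t), h));
    uint64 hb = Hash64CombineUnordered(
        FastAttrValueHash(t),
        Hash64CombineUnordered(FastAttrValueHash(transpose_a), h));
    // ha == hb: the signature does not depend on the iteration order of the
    // node's attr map.
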
@@ -1839,17 +1872,8 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
}
for (const auto& attr1 : node1.attr()) {
auto it = node2.attr().find(attr1.first);
- if (it == node2.attr().end()) {
- return false;
- }
- const auto& attr2 = *it;
- string val1;
- attr1.second.AppendToString(&val1);
- string val2;
- attr2.second.AppendToString(&val2);
- if (val1 != val2) {
- return false;
- }
+ if (it == node2.attr().end()) return false;
+ if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false;
}
return true;
@@ -2233,6 +2257,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
new_square_node->set_input(i - 1, new_square_node->input(i));
}
new_square_node->mutable_input()->RemoveLast();
+ for (const string& input : new_square_node->input()) {
+ node_map_->AddOutput(NodeName(input), new_square_node->name());
+ }
return new_square_node->name();
}
}
@@ -2398,7 +2425,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.minimize_broadcasts && can_use_shapes)
pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
if (options_.remove_identity_transpose && can_use_shapes)
- pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
+ pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext, opt_level_);
if (options_.remove_redundant_bitcast)
pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
if (options_.remove_redundant_cast)
@@ -2491,7 +2518,8 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
graph_properties_.reset(new GraphProperties(optimized_item));
- const Status status = graph_properties_->InferStatically(false);
+ const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
+ const Status status = graph_properties_->InferStatically(assume_valid_feeds);
const bool can_use_shapes = status.ok();
if (!can_use_shapes) {
VLOG(1) << "Shape inference failed." << status.error_message();
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 067adb359c..27c0dde419 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -964,6 +964,67 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
+TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs =
+ ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
+ Output target_shape = ops::Const(s, {4, 3, 28, 28}, {4});
+ Output reshape = ops::Reshape(s, inputs, target_shape);
+ Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
+
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ item.feed = {{"Placeholder", x_t}};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+
+ item.graph.Swap(&output);
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ // The reshape is preserved because the shape of the placeholder can be
+ // different from the shape of the actual feed.
+ EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
+
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
+TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs =
+ ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
+ Output target_shape = ops::Const(s, {4, 3, 28, 28}, {4});
+ Output reshape = ops::Reshape(s, inputs, target_shape);
+ Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
+
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ item.feed = {{"Placeholder", x_t}};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
+ .Optimize(nullptr, item, &output));
+
+ item.graph.Swap(&output);
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
// Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can
// be from [4,3,28,28] to [8,6,28,28].
@@ -1122,7 +1183,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
- Output perm3 = ops::Const(s.WithOpName("perm2"), {0, 1, 2, 3}, {4});
+ Output perm3 = ops::Const(s.WithOpName("perm3"), {0, 1, 2, 3}, {4});
Output transpose1 = ops::Transpose(s.WithOpName("transpose1"), inputs, perm1);
Output transpose2 =
ops::Transpose(s.WithOpName("transpose2"), transpose1, perm2);
@@ -1248,6 +1309,47 @@ TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
EXPECT_EQ(6, output.node_size());
}
+TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs_shape =
+ ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
+ Output inputs =
+ ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
+ Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
+ Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
+ Output transpose1 = ops::Transpose(
+ s.WithOpName("transpose1").WithControlDependencies(perm2), inputs, perm1);
+ Output identity = ops::Identity(s.WithOpName("id"), transpose1);
+ Output transpose2 =
+ ops::Transpose(s.WithOpName("transpose2"), identity, perm2);
+ Output id1 = ops::Identity(s.WithOpName("id1"), transpose2);
+
+ GrapplerItem item;
+ item.fetch = {"id1"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ EnableOnlyRemoveIdentityTranspose(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ std::set<string> nodes_after_optimization;
+ for (const NodeDef& node : output.node()) {
+ nodes_after_optimization.insert(node.name());
+ if (node.name() == "id") {
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("inputs", node.input(0));
+ EXPECT_EQ("^perm2", node.input(1));
+ }
+ if (node.name() == "id1") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("id", node.input(0));
+ }
+ }
+ EXPECT_EQ(nodes_after_optimization,
+ std::set<string>({"id", "id1", "inputs_shape", "inputs", "perm2"}));
+}
+
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
@@ -1574,6 +1676,14 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ std::vector<std::pair<string, Tensor>> feed = {
+ {"a", a_t}, {"b", b_t}, {"c", c_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyAddToAddNCombining(&optimizer);
@@ -1607,6 +1717,10 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
ASSERT_NE(updated_outputs, nullptr);
EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
+
+ auto tensors = EvaluateNodes(output, item.fetch, feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
@@ -1631,6 +1745,17 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ std::vector<std::pair<string, Tensor>> feed = {
+ {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyAddToAddNCombining(&optimizer);
@@ -1680,6 +1805,10 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
EXPECT_EQ(2, updated_mul->input_size());
EXPECT_EQ(collapsed_left->name(), updated_mul->input(0));
EXPECT_EQ(collapsed_right->name(), updated_mul->input(1));
+
+ auto tensors = EvaluateNodes(output, item.fetch, feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
@@ -1697,6 +1826,14 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ std::vector<std::pair<string, Tensor>> feed = {
+ {"a", a_t}, {"b", b_t}, {"c", c_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyAddToAddNCombining(&optimizer);
@@ -1725,6 +1862,10 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
EXPECT_EQ("b", collapsed_add->input(1));
EXPECT_EQ("b", collapsed_add->input(2));
EXPECT_EQ("c", collapsed_add->input(3));
+
+ auto tensors = EvaluateNodes(output, item.fetch, feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
@@ -1748,6 +1889,11 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ std::vector<std::pair<string, Tensor>> feed = {{"input", x_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyAddToAddNCombining(&optimizer);
@@ -1779,6 +1925,10 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
const NodeDef* updated_outputs = node_map.GetNode("outputs");
ASSERT_NE(updated_outputs, nullptr);
EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
+
+ auto tensors = EvaluateNodes(output, item.fetch, feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCast) {
@@ -1803,6 +1953,17 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCast) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
+ auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
+ auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
+ auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
+ auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
+ std::vector<std::pair<string, Tensor>> feed = {
+ {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyAddToAddNCombining(&optimizer);
@@ -1875,18 +2036,22 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCast) {
const NodeDef* updated_outputs = node_map.GetNode("outputs");
ASSERT_NE(updated_outputs, nullptr);
EXPECT_EQ(outer_add_name, updated_outputs->input(0));
+
+ auto tensors = EvaluateNodes(output, item.fetch, feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCastWithSymbolicShapes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
// We have a small input with one unknown dimension
- auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_FLOAT);
+ auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_DOUBLE);
// And second input which is larger, but has the same unknown dimension
// device spec prevents this node from rewriting
- auto d = "/job:do_not_rewrite_me";
- auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_FLOAT);
+ auto d = "/device:CPU:0";
+ auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_DOUBLE);
auto large = ops::Add(s.WithOpName("large").WithDevice(d), small, v);
// [a, c] have {?, 1, 1} shape, [b] has {?, 32, 32}
@@ -1904,6 +2069,12 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCastWithSymbolicShapes) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto s_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({8, 1, 1}));
+ auto v_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({1, 32, 32}));
+ std::vector<std::pair<string, Tensor>> feed = {{"small", s_t}, {"v", v_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyAddToAddNCombining(&optimizer);
@@ -1942,6 +2113,10 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MinimizeBCastWithSymbolicShapes) {
const NodeDef* updated_outputs = node_map.GetNode("outputs");
ASSERT_NE(updated_outputs, nullptr);
EXPECT_EQ(outer_add_name, updated_outputs->input(0));
+
+ auto tensors = EvaluateNodes(output, item.fetch, feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
@@ -1966,6 +2141,12 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
item.fetch = {"add_all"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ std::vector<std::pair<string, Tensor>> feed = {{"x", x_t}, {"y", y_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveNegation(&optimizer);
@@ -2014,6 +2195,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
}
}
EXPECT_EQ(5, found);
+
+ auto tensors = EvaluateNodes(output, item.fetch, feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
@@ -2069,6 +2254,14 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
+ auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
+ auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
+ std::vector<std::pair<string, Tensor>> feed = {
+ {"a", a_t}, {"b", b_t}, {"c", c_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyMinimizeBroadcasts(&optimizer);
@@ -2093,16 +2286,20 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
ASSERT_NE(mul2_node, nullptr);
EXPECT_EQ("mul1", mul2_node->input(0));
EXPECT_EQ("b", mul2_node->input(1));
+
+ auto tensors = EvaluateNodes(output, item.fetch, feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
- auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
- auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
- auto d = ops::Variable(s.WithOpName("d"), {32}, DT_FLOAT);
- auto e = ops::Variable(s.WithOpName("e"), {32}, DT_FLOAT);
+ auto a = ops::Variable(s.WithOpName("a"), {32}, DT_DOUBLE);
+ auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_DOUBLE);
+ auto c = ops::Variable(s.WithOpName("c"), {32}, DT_DOUBLE);
+ auto d = ops::Variable(s.WithOpName("d"), {32}, DT_DOUBLE);
+ auto e = ops::Variable(s.WithOpName("e"), {32}, DT_DOUBLE);
auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
@@ -2115,6 +2312,16 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto a_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
+ auto b_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32, 32}));
+ auto c_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
+ auto d_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
+ auto e_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
+ std::vector<std::pair<string, Tensor>> feed = {
+ {"a", a_t}, {"b", b_t}, {"c", c_t}, {"d", d_t}, {"e", e_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyMinimizeBroadcasts(&optimizer);
@@ -2154,6 +2361,10 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
ASSERT_NE(mul4_node, nullptr);
EXPECT_EQ("mul3", mul4_node->input(0));
EXPECT_EQ("b", mul4_node->input(1));
+
+ auto tensors = EvaluateNodes(output, item.fetch, feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
@@ -2175,6 +2386,15 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
+ auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
+ auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
+ auto d_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
+ std::vector<std::pair<string, Tensor>> feed = {
+ {"a", a_t}, {"b", b_t}, {"c", c_t}, {"D", d_t}};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyMinimizeBroadcasts(&optimizer);
@@ -2206,6 +2426,10 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
ASSERT_NE(mul3_node, nullptr);
EXPECT_EQ("D", mul3_node->input(0));
EXPECT_EQ("mul1", mul3_node->input(1));
+
+ auto tensors = EvaluateNodes(output, item.fetch, feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index e6a74dbdcd..b2dcbf9df5 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1514,6 +1514,16 @@ void ConstantFolding::ReplaceOperationWithIdentity(
void ConstantFolding::ReplaceOperationWithSnapshot(
int input_to_forward, const GraphProperties& properties, NodeDef* node,
GraphDef* graph) {
+ // If the graph contains no ops that mutate their inputs, we can
+ // use Identity instead of Snapshot.
+
+ // TODO(rmlarsen): Enable in regular mode after May 15, 2018.
+ if (opt_level_ == RewriterConfig::AGGRESSIVE &&
+ !graph_contains_assign_or_inplace_op_) {
+ ReplaceOperationWithIdentity(input_to_forward, properties, node, graph);
+ return;
+ }
+
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
if (dtype == DT_INVALID) return;
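
A brief sketch of the intent of the guard added above: Snapshot makes a copy of its input while Identity merely forwards it, so the cheaper op is only chosen in aggressive mode when the scan recorded in graph_contains_assign_or_inplace_op_ found nothing that could mutate the forwarded buffer. The helper below is illustrative only, with invented names.

#include <string>

// Illustrative only: pick the forwarding op the same way the new guard does.
std::string ForwardingOp(bool aggressive_mode, bool graph_has_mutating_ops) {
  return (aggressive_mode && !graph_has_mutating_ops) ? "Identity" : "Snapshot";
}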
@@ -1587,796 +1597,811 @@ Status ConstantFolding::ReplaceOperationWithConstant(
Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
GraphProperties* properties,
bool use_shape_info) {
- const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
for (int i = 0; i < optimized_graph->node_size(); ++i) {
- NodeDef* node = optimized_graph->mutable_node(i);
+ TF_RETURN_IF_ERROR(SimplifyNode(optimized_graph->mutable_node(i),
+ optimized_graph, properties,
+ use_shape_info));
+ }
+ return Status::OK();
+}
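
A minimal sketch of the refactoring shape applied above, with placeholder types standing in for GraphDef/NodeDef and Status: the per-node loop body moves into its own function, each former `continue` becomes an early `return Status::OK()`, and real errors still propagate through the returned status.

#include <vector>

struct Status { bool ok; };                      // placeholder for tensorflow::Status
static Status OkStatus() { return Status{true}; }

// Placeholder for SimplifyNode: early-return where the old body would `continue`.
Status SimplifyOne(int* node) {
  if (*node % 2 == 0) return OkStatus();         // was: continue;
  *node *= 3;                                    // the actual per-node rewrite
  return OkStatus();
}

// Placeholder for SimplifyGraph: the loop now only dispatches and checks status.
Status SimplifyAll(std::vector<int>* nodes) {
  for (int& n : *nodes) {
    Status s = SimplifyOne(&n);
    if (!s.ok) return s;                         // was: TF_RETURN_IF_ERROR(...)
  }
  return OkStatus();
}

int main() {
  std::vector<int> nodes = {1, 2, 3};
  Status s = SimplifyAll(&nodes);
  return s.ok ? 0 : 1;
}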
- if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
- ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
- continue;
- }
+Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
+ GraphProperties* properties,
+ bool use_shape_info) {
+ const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
+ if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
+ return Status::OK();
+ }
- if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- continue;
- }
+ if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ return Status::OK();
+ }
- // Remove Shuffle or Transpose op over dimensions of size 1.
- if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
- properties->GetInputProperties(node->name()).size() >= 2) {
- const auto& shape =
- properties->GetInputProperties(node->name())[0].shape();
- if (shape.unknown_rank()) {
- // Not optimizable.
- continue;
- }
- const auto& p = properties->GetInputProperties(node->name())[1];
- if (TensorShape::IsValid(p.shape()) && p.has_value()) {
- Tensor perm(p.dtype(), p.shape());
- if (!perm.FromProto(p.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- p.value().DebugString());
- }
- std::vector<int> permutation;
- for (int j = 0; j < perm.NumElements(); ++j) {
- if (perm.dtype() == DT_INT64) {
- permutation.push_back(perm.vec<int64>()(j));
- } else {
- permutation.push_back(perm.vec<int>()(j));
- }
- }
- if (permutation.size() != shape.dim_size()) {
- // Number of elements in perm should be same as dim_size. Skip if not.
- continue;
- }
- // The node is replaceable iff
- // dim_size == 0 || all dims have size 1 ||
- // all dims with > 1 size are not permuted.
- bool replaceable = true;
- for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
- replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- continue;
+ // Remove Shuffle or Transpose op over dimensions of size 1.
+ if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
+ properties->GetInputProperties(node->name()).size() >= 2) {
+ const auto& shape = properties->GetInputProperties(node->name())[0].shape();
+ if (shape.unknown_rank()) {
+ // Not optimizable.
+ return Status::OK();
+ }
+ const auto& p = properties->GetInputProperties(node->name())[1];
+ if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+ Tensor perm(p.dtype(), p.shape());
+ if (!perm.FromProto(p.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ p.value().DebugString());
+ }
+ std::vector<int> permutation;
+ for (int j = 0; j < perm.NumElements(); ++j) {
+ if (perm.dtype() == DT_INT64) {
+ permutation.push_back(perm.vec<int64>()(j));
+ } else {
+ permutation.push_back(perm.vec<int>()(j));
}
}
- }
-
- // Remove RandomShuffle op if it is scalar or first dimension is of size 1.
- if (use_shape_info && IsRandomShuffle(*node) &&
- !properties->GetInputProperties(node->name()).empty()) {
- const auto& shape =
- properties->GetInputProperties(node->name())[0].shape();
+ if (permutation.size() != shape.dim_size()) {
+ // Number of elements in perm should be the same as dim_size. Skip if not.
+ return Status::OK();
+ }
// The node is replaceable iff
- // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
- if (!shape.unknown_rank() &&
- (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
+ // dim_size == 0 || all dims have size 1 ||
+ // all dims with > 1 size are not permuted.
+ bool replaceable = true;
+ for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+ replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
+ }
+ if (replaceable) {
ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- continue;
+ return Status::OK();
}
}
+ }
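
The replaceability condition above can be stated on plain shapes. As a standalone sketch (shapes and permutations reduced to vectors of int64_t), a Transpose or Shuffle is a no-op exactly when every dimension of size greater than one is left in place by the permutation.

#include <cassert>
#include <cstdint>
#include <vector>

bool TransposeIsNoOp(const std::vector<int64_t>& shape,
                     const std::vector<int64_t>& perm) {
  if (perm.size() != shape.size()) return false;
  for (size_t j = 0; j < shape.size(); ++j) {
    if (shape[j] > 1 && perm[j] != static_cast<int64_t>(j)) return false;
  }
  return true;
}

int main() {
  assert(TransposeIsNoOp({1, 5, 1}, {2, 1, 0}));  // only size-1 dims move
  assert(!TransposeIsNoOp({2, 3}, {1, 0}));       // a real transpose
}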
- // Remove Reverse op over dimensions with size 1.
- if (use_shape_info && node->op() == "ReverseV2" &&
- properties->GetInputProperties(node->name()).size() >= 2) {
- const auto& shape =
- properties->GetInputProperties(node->name())[0].shape();
- if (shape.unknown_rank()) {
- // Not optimizable.
- continue;
- }
- const auto& a = properties->GetInputProperties(node->name())[1];
- if (TensorShape::IsValid(a.shape()) && a.has_value()) {
- Tensor axis(a.dtype(), a.shape());
- if (!axis.FromProto(a.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- a.value().DebugString());
- }
- std::set<int> target_axes;
- for (int j = 0; j < axis.NumElements(); ++j) {
- // value of axis can be negative.
- if (axis.dtype() == DT_INT64) {
- target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
- shape.dim_size());
- } else {
- target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
- shape.dim_size());
- }
- }
-
- // The node is replaceable iff
- // unknown_rank == false &&
- // (dim_size == 0 || all dims have size 1 ||
- // all dims with > 1 size are not in target_axes)
- bool replaceable = !shape.unknown_rank();
- for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
- replaceable &= shape.dim(j).size() == 1 ||
- target_axes.find(j) == target_axes.end();
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- continue;
- }
- }
+ // Remove RandomShuffle op if it is scalar or first dimension is of size 1.
+ if (use_shape_info && IsRandomShuffle(*node) &&
+ !properties->GetInputProperties(node->name()).empty()) {
+ const auto& shape = properties->GetInputProperties(node->name())[0].shape();
+ // The node is replaceable iff
+ // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
+ if (!shape.unknown_rank() &&
+ (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
+ ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ return Status::OK();
}
+ }
- if (use_shape_info && IsSlice(*node) &&
- properties->GetInputProperties(node->name()).size() == 3) {
- const auto& input = properties->GetInputProperties(node->name())[0];
- const auto& b = properties->GetInputProperties(node->name())[1];
- const auto& s = properties->GetInputProperties(node->name())[2];
- if (TensorShape::IsValid(b.shape()) && b.has_value() &&
- TensorShape::IsValid(s.shape()) && s.has_value()) {
- Tensor begin(b.dtype(), b.shape());
- if (!begin.FromProto(b.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- b.value().DebugString());
- }
- Tensor size(s.dtype(), s.shape());
- if (!size.FromProto(s.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- s.value().DebugString());
- }
- // The node is replaceable iff unknown_rank == false &&
- // begin == 0 && (size == -1 || size == input_shape) for all dimensions
- bool replaceable = !input.shape().unknown_rank();
- for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
- if (begin.dtype() == DT_INT32) {
- replaceable &= begin.vec<int>()(j) == 0;
- } else {
- replaceable &= begin.vec<int64>()(j) == 0;
- }
- if (size.dtype() == DT_INT32) {
- replaceable &= (size.vec<int>()(j) == -1 ||
- size.vec<int>()(j) == input.shape().dim(j).size());
- } else {
- replaceable &=
- (size.vec<int64>()(j) == -1 ||
- size.vec<int64>()(j) == input.shape().dim(j).size());
- }
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- continue;
- }
- }
+ // Remove Reverse op over dimensions with size 1.
+ if (use_shape_info && node->op() == "ReverseV2" &&
+ properties->GetInputProperties(node->name()).size() >= 2) {
+ const auto& shape = properties->GetInputProperties(node->name())[0].shape();
+ if (shape.unknown_rank()) {
+ // Not optimizable.
+ return Status::OK();
}
-
- if (use_shape_info && IsTile(*node) &&
- properties->GetInputProperties(node->name()).size() == 2) {
- const auto& m = properties->GetInputProperties(node->name())[1];
- if (TensorShape::IsValid(m.shape()) && m.has_value()) {
- Tensor multiplies(m.dtype(), m.shape());
- if (!multiplies.FromProto(m.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- m.value().DebugString());
- }
- // The node is replaceable iff all values in multiplies are 1.
- bool replaceable = true;
- if (multiplies.dtype() == DT_INT32) {
- for (int j = 0; replaceable && j < multiplies.vec<int>().size();
- ++j) {
- replaceable &= multiplies.vec<int>()(j) == 1;
- }
+ const auto& a = properties->GetInputProperties(node->name())[1];
+ if (TensorShape::IsValid(a.shape()) && a.has_value()) {
+ Tensor axis(a.dtype(), a.shape());
+ if (!axis.FromProto(a.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ a.value().DebugString());
+ }
+ std::set<int> target_axes;
+ for (int j = 0; j < axis.NumElements(); ++j) {
+ // value of axis can be negative.
+ if (axis.dtype() == DT_INT64) {
+ target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
+ shape.dim_size());
} else {
- for (int j = 0; replaceable && j < multiplies.vec<int64>().size();
- ++j) {
- replaceable &= multiplies.vec<int64>()(j) == 1;
- }
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- continue;
+ target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
+ shape.dim_size());
}
}
- }
- if (use_shape_info && IsPad(*node) &&
- properties->GetInputProperties(node->name()).size() >= 2) {
- const auto& p = properties->GetInputProperties(node->name())[1];
- if (TensorShape::IsValid(p.shape()) && p.has_value()) {
- Tensor paddings(p.dtype(), p.shape());
- if (!paddings.FromProto(p.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- p.value().DebugString());
- }
- // The node is replaceable iff all values in paddings are 0.
- bool replaceable = true;
- // The operation requires it to be int32 value so we don't check for
- // 1nt64.
- const auto flatten = paddings.flat<int32>();
- for (int j = 0; replaceable && j < flatten.size(); ++j) {
- replaceable &= flatten(j) == 0;
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- continue;
- }
- }
- }
-
- if (use_shape_info && IsSqueeze(*node) &&
- !properties->GetInputProperties(node->name()).empty()) {
- // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
- // error to squeeze a dimension that is not 1, so we only need to check
- // whether the input has > 1 size for each dimension.
- const auto& shape =
- properties->GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
- // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
+ // unknown_rank == false &&
+ // (dim_size == 0 || all dims have size 1 ||
+ // all dims with > 1 size are not in target_axes)
bool replaceable = !shape.unknown_rank();
for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
- replaceable &= shape.dim(j).size() > 1;
+ replaceable &= shape.dim(j).size() == 1 ||
+ target_axes.find(j) == target_axes.end();
}
if (replaceable) {
ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- continue;
+ return Status::OK();
}
}
+ }
- if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
- !OptimizedNodeExists(*node, "_const_axis")) {
- // Create constant axis node.
- Tensor axis_t(DT_INT32, TensorShape({}));
- NodeDef* axis_node = optimized_graph->add_node();
- axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
- const int axis = node->attr().at("axis").i();
- if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
- !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
- .ok()) {
- continue;
+ if (use_shape_info && IsSlice(*node) &&
+ properties->GetInputProperties(node->name()).size() == 3) {
+ const auto& input = properties->GetInputProperties(node->name())[0];
+ const auto& b = properties->GetInputProperties(node->name())[1];
+ const auto& s = properties->GetInputProperties(node->name())[2];
+ if (TensorShape::IsValid(b.shape()) && b.has_value() &&
+ TensorShape::IsValid(s.shape()) && s.has_value()) {
+ Tensor begin(b.dtype(), b.shape());
+ if (!begin.FromProto(b.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ b.value().DebugString());
}
- // Add a control dependency to make sure axis_node is in the right frame.
- const string ctrl_dep = ConstantFolding::AddControlDependency(
- node->input(0), graph_, node_map_.get());
- axis_node->add_input(ctrl_dep);
- axis_node->set_device(node->device());
- node->set_op("ExpandDims");
- if (node->attr().count("axis") != 0) {
- node->mutable_attr()->erase("axis");
+ Tensor size(s.dtype(), s.shape());
+ if (!size.FromProto(s.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ s.value().DebugString());
}
- if (node->attr().count("N") != 0) {
- node->mutable_attr()->erase("N");
+ // The node is replaceable iff unknown_rank == false &&
+ // begin == 0 && (size == -1 || size == input_shape) for all dimensions
+ bool replaceable = !input.shape().unknown_rank();
+ for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
+ if (begin.dtype() == DT_INT32) {
+ replaceable &= begin.vec<int>()(j) == 0;
+ } else {
+ replaceable &= begin.vec<int64>()(j) == 0;
+ }
+ if (size.dtype() == DT_INT32) {
+ replaceable &= (size.vec<int>()(j) == -1 ||
+ size.vec<int>()(j) == input.shape().dim(j).size());
+ } else {
+ replaceable &= (size.vec<int64>()(j) == -1 ||
+ size.vec<int64>()(j) == input.shape().dim(j).size());
+ }
}
- (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
- node->add_input(axis_node->name());
- if (node->input_size() > 2) {
- node->mutable_input()->SwapElements(1, node->input_size() - 1);
+ if (replaceable) {
+ ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ return Status::OK();
}
- graph_modified_ = true;
- continue;
}
+ }
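
As a standalone sketch of the condition checked above (inputs reduced to vectors of int64_t): a Slice forwards its input unchanged exactly when every begin index is 0 and every size is either -1 or the full extent of that dimension.

#include <cassert>
#include <cstdint>
#include <vector>

bool SliceIsNoOp(const std::vector<int64_t>& input_shape,
                 const std::vector<int64_t>& begin,
                 const std::vector<int64_t>& size) {
  for (size_t j = 0; j < input_shape.size(); ++j) {
    if (begin[j] != 0) return false;
    if (size[j] != -1 && size[j] != input_shape[j]) return false;
  }
  return true;
}

int main() {
  assert(SliceIsNoOp({4, 8}, {0, 0}, {-1, 8}));   // whole tensor: forward it
  assert(!SliceIsNoOp({4, 8}, {1, 0}, {-1, 8}));  // real slice: keep the op
}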
- // Move constants past Enter.
- if (IsEnter(*node) && node->input_size() > 0) {
- if (node->attr().count("is_constant") == 0 ||
- !node->attr().at("is_constant").b()) {
- continue;
+ if (use_shape_info && IsTile(*node) &&
+ properties->GetInputProperties(node->name()).size() == 2) {
+ const auto& m = properties->GetInputProperties(node->name())[1];
+ if (TensorShape::IsValid(m.shape()) && m.has_value()) {
+ Tensor multiplies(m.dtype(), m.shape());
+ if (!multiplies.FromProto(m.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ m.value().DebugString());
}
- const string& node_name = node->name();
- const NodeDef* input = node_map_->GetNode(node->input(0));
- if (input != nullptr && IsReallyConstant(*input) &&
- !OptimizedNodeExists(*input, "_enter")) {
- auto fanouts = node_map_->GetOutputs(node_name);
- // Find non-constant nodes that consume the output of *node.
- std::vector<NodeDef*> consumers;
- for (NodeDef* fanout : fanouts) {
- if (!IsConstant(*fanout)) {
- for (int i = 0; i < fanout->input_size(); ++i) {
- if (fanout->input(i) == node_name) {
- consumers.push_back(fanout);
- break;
- }
- }
- }
+ // The node is replaceable iff all values in multiplies are 1.
+ bool replaceable = true;
+ if (multiplies.dtype() == DT_INT32) {
+ for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
+ replaceable &= multiplies.vec<int>()(j) == 1;
}
- if (!consumers.empty()) {
- NodeDef* new_node = optimized_graph->add_node();
- *new_node = *input;
- new_node->set_name(OptimizedNodeName(*input, "_enter"));
- new_node->set_device(node->device());
- new_node->clear_input();
- new_node->add_input(AsControlDependency(node_name));
- node_map_->AddNode(new_node->name(), new_node);
- node_map_->AddOutput(node_name, new_node->name());
- for (NodeDef* consumer : consumers) {
- for (int i = 0; i < consumer->input_size(); ++i) {
- if (NodeName(consumer->input(i)) == node_name) {
- node_map_->UpdateInput(consumer->name(), node_name,
- new_node->name());
- consumer->set_input(i, new_node->name());
- }
- }
- }
- graph_modified_ = true;
- continue;
+ } else {
+ for (int j = 0; replaceable && j < multiplies.vec<int64>().size();
+ ++j) {
+ replaceable &= multiplies.vec<int64>()(j) == 1;
}
}
+ if (replaceable) {
+ ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ return Status::OK();
+ }
}
+ }
- // Switch(x, x) will always feed false to its false branch and true to
- // its true branch. By rewriting the graph a bit, we can propagate these
- // constants down the two output branches, and just use control dependencies
- // to trigger the selected one at runtime. For example,
- //
- // +------+
- // x-->|Switch|-->a (in practice there may be multiple consumers of each
- // x-->| |-->b output branch.)
- // +------+
- //
- // Is rewritten as
- //
- // +------+
- // x-->|Switch|-->Identity--^>Const(false)-->a
- // x-->| |-->Identity--^>Const(true)-->b
- // +------+
- if (node->op() == "Switch" && node->input(0) == node->input(1) &&
- !OptimizedNodeExists(*node, "_const_false") &&
- !OptimizedNodeExists(*node, "_const_true")) {
- bool already_optimized = true;
- // If the optimization was already applied, the switch would have exactly
- // one Identity node consuming each of its outputs, each without any
- // non-control outputs.
- auto fanouts = node_map_->GetOutputs(node->name());
- if (fanouts.size() == 2) {
- for (NodeDef* fanout : fanouts) {
- if (!IsIdentity(*fanout) ||
- NumNonControlOutputs(*fanout, *node_map_) > 0) {
- already_optimized = false;
- break;
- }
- }
+ if (use_shape_info && IsPad(*node) &&
+ properties->GetInputProperties(node->name()).size() >= 2) {
+ const auto& p = properties->GetInputProperties(node->name())[1];
+ if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+ Tensor paddings(p.dtype(), p.shape());
+ if (!paddings.FromProto(p.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ p.value().DebugString());
}
- Tensor false_t(DT_BOOL, TensorShape({}));
- Tensor true_t(DT_BOOL, TensorShape({}));
- // Make sure we don't proceed if this switch node was already optimized.
- if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
- SetTensorValue(DT_BOOL, false, &false_t).ok()) {
- // Copy the set of consumers of the switch as they will be manipulated
- // below.
- const std::set<NodeDef*>& consumer_set =
- node_map_->GetOutputs(node->name());
- std::vector<NodeDef*> consumers(consumer_set.begin(),
- consumer_set.end());
- std::sort(consumers.begin(), consumers.end(),
- [](const NodeDef* n1, const NodeDef* n2) {
- return n1->name() < n2->name();
- });
- // Create constant false & true nodes.
- NodeDef* false_node = optimized_graph->add_node();
- false_node->set_name(OptimizedNodeName(*node, "_const_false"));
- if (!CreateNodeDef(false_node->name(), TensorValue(&false_t),
- false_node)
- .ok()) {
- continue;
- }
- false_node->set_device(node->device());
+ // The node is replaceable iff all values in paddings are 0.
+ bool replaceable = true;
+ // The operation requires it to be an int32 value so we don't check for
+ // int64.
+ const auto flatten = paddings.flat<int32>();
+ for (int j = 0; replaceable && j < flatten.size(); ++j) {
+ replaceable &= flatten(j) == 0;
+ }
+ if (replaceable) {
+ ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ return Status::OK();
+ }
+ }
+ }
- NodeDef* true_node = optimized_graph->add_node();
- true_node->set_name(OptimizedNodeName(*node, "_const_true"));
- if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
- .ok()) {
- continue;
- }
- true_node->set_device(node->device());
-
- // Add controls from the switch ports to the constants, and connect the
- // constants to the original switch outputs.
- const string false_port = node->name();
- const string true_port = strings::StrCat(node->name(), ":1");
- const string false_ctrl_dep =
- AddControlDependency(false_port, optimized_graph, node_map_.get());
- false_node->add_input(false_ctrl_dep);
- const string true_ctrl_dep =
- AddControlDependency(true_port, optimized_graph, node_map_.get());
- true_node->add_input(true_ctrl_dep);
-
- node_map_->AddNode(false_node->name(), false_node);
- node_map_->AddNode(true_node->name(), true_node);
- node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
- node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());
+ if (use_shape_info && IsSqueeze(*node) &&
+ !properties->GetInputProperties(node->name()).empty()) {
+ // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's an
+ // error to squeeze a dimension that is not 1, so we only need to check
+ // whether the input has > 1 size for each dimension.
+ const auto& shape = properties->GetInputProperties(node->name())[0].shape();
+ // The node is replaceable iff
+ // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
+ bool replaceable = !shape.unknown_rank();
+ for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+ replaceable &= shape.dim(j).size() > 1;
+ }
+ if (replaceable) {
+ ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ return Status::OK();
+ }
+ }
+ if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
+ !OptimizedNodeExists(*node, "_const_axis")) {
+ // Create constant axis node.
+ Tensor axis_t(DT_INT32, TensorShape({}));
+ NodeDef* axis_node = optimized_graph->add_node();
+ axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
+ const int axis = node->attr().at("axis").i();
+ if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
+ !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
+ .ok()) {
+ return Status::OK();
+ }
+ // Add a control dependency to make sure axis_node is in the right frame.
+ const string ctrl_dep = ConstantFolding::AddControlDependency(
+ node->input(0), graph_, node_map_.get());
+ axis_node->add_input(ctrl_dep);
+ axis_node->set_device(node->device());
+ node->set_op("ExpandDims");
+ if (node->attr().count("axis") != 0) {
+ node->mutable_attr()->erase("axis");
+ }
+ if (node->attr().count("N") != 0) {
+ node->mutable_attr()->erase("N");
+ }
+ (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
+ node->add_input(axis_node->name());
+ if (node->input_size() > 2) {
+ node->mutable_input()->SwapElements(1, node->input_size() - 1);
+ }
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
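As a standalone illustration of why the Pack-to-ExpandDims rewrite above is shape-preserving: stacking N tensors along `axis` inserts a dimension of size N there, so for a single input (N == 1) Pack produces exactly the shape ExpandDims would.

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> PackShape(std::vector<int64_t> shape, int64_t n, int axis) {
  shape.insert(shape.begin() + axis, n);  // Pack adds a dim of size N at axis
  return shape;
}

std::vector<int64_t> ExpandDimsShape(std::vector<int64_t> shape, int axis) {
  shape.insert(shape.begin() + axis, 1);  // ExpandDims adds a dim of size 1
  return shape;
}

int main() {
  const std::vector<int64_t> in = {2, 3};
  assert(PackShape(in, /*n=*/1, /*axis=*/1) == ExpandDimsShape(in, /*axis=*/1));
  assert(PackShape(in, 1, 1) == (std::vector<int64_t>{2, 1, 3}));
}
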
+ // Move constants past Enter.
+ if (IsEnter(*node) && node->input_size() > 0) {
+ if (node->attr().count("is_constant") == 0 ||
+ !node->attr().at("is_constant").b()) {
+ return Status::OK();
+ }
+ const string& node_name = node->name();
+ const NodeDef* input = node_map_->GetNode(node->input(0));
+ if (input != nullptr && IsReallyConstant(*input) &&
+ !OptimizedNodeExists(*input, "_enter")) {
+ auto fanouts = node_map_->GetOutputs(node_name);
+ // Find non-constant nodes that consume the output of *node.
+ std::vector<NodeDef*> consumers;
+ for (NodeDef* fanout : fanouts) {
+ if (!IsConstant(*fanout)) {
+ for (int i = 0; i < fanout->input_size(); ++i) {
+ if (fanout->input(i) == node_name) {
+ consumers.push_back(fanout);
+ break;
+ }
+ }
+ }
+ }
+ if (!consumers.empty()) {
+ NodeDef* new_node = optimized_graph->add_node();
+ *new_node = *input;
+ new_node->set_name(OptimizedNodeName(*input, "_enter"));
+ new_node->set_device(node->device());
+ new_node->clear_input();
+ new_node->add_input(AsControlDependency(node_name));
+ node_map_->AddNode(new_node->name(), new_node);
+ node_map_->AddOutput(node_name, new_node->name());
for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) {
- const string& input = consumer->input(i);
- if (input == false_port) {
- consumer->set_input(i, false_node->name());
- node_map_->UpdateInput(consumer->name(), false_port,
- false_node->name());
- } else if (input == true_port) {
- consumer->set_input(i, true_node->name());
- node_map_->UpdateInput(consumer->name(), true_port,
- true_node->name());
+ if (NodeName(consumer->input(i)) == node_name) {
+ node_map_->UpdateInput(consumer->name(), node_name,
+ new_node->name());
+ consumer->set_input(i, new_node->name());
}
}
}
graph_modified_ = true;
- continue;
+ return Status::OK();
}
}
- if (IsSimplifiableReduction(*node)) {
- // Replace the reduction node with an identity node, that can be further
- // optimized by the model pruner.
- DataType output_type;
- if (node->attr().count("T") > 0) {
- output_type = node->attr().at("T").type();
- } else {
- // This is an 'any' or 'all' reduction. The output is always boolean.
- output_type = DT_BOOL;
+ }
+
+ // Switch(x, x) will always feed false to its false branch and true to
+ // its true branch. By rewriting the graph a bit, we can propagate these
+ // constants down the two output branches, and just use control dependencies
+ // to trigger the selected one at runtime. For example,
+ //
+ // +------+
+ // x-->|Switch|-->a (in practice there may be multiple consumers of each
+ // x-->| |-->b output branch.)
+ // +------+
+ //
+ // Is rewritten as
+ //
+ // +------+
+ // x-->|Switch|-->Identity--^>Const(false)-->a
+ // x-->| |-->Identity--^>Const(true)-->b
+ // +------+
+ if (node->op() == "Switch" && node->input(0) == node->input(1) &&
+ !OptimizedNodeExists(*node, "_const_false") &&
+ !OptimizedNodeExists(*node, "_const_true")) {
+ bool already_optimized = true;
+ // If the optimization was already applied, the switch would have exactly
+ // one Identity node consuming each of its outputs, each without any
+ // non-control outputs.
+ auto fanouts = node_map_->GetOutputs(node->name());
+ if (fanouts.size() == 2) {
+ for (NodeDef* fanout : fanouts) {
+ if (!IsIdentity(*fanout) ||
+ NumNonControlOutputs(*fanout, *node_map_) > 0) {
+ already_optimized = false;
+ break;
+ }
}
- node->set_op("Identity");
- node->clear_attr();
- (*node->mutable_attr())["T"].set_type(output_type);
- *node->mutable_input(1) = AsControlDependency(node->input(1));
- graph_modified_ = true;
- continue;
}
- if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
- DataType output_type = node->attr().at("T").type();
- node->set_op("Identity");
- node->clear_attr();
- (*node->mutable_attr())["T"].set_type(output_type);
- *node->mutable_input(1) = AsControlDependency(node->input(1));
- graph_modified_ = true;
- continue;
- }
-
- const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
- const bool is_matmul = IsMatMul(*node);
- const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
- const bool is_sub = IsSub(*node);
- const bool is_any_div = IsAnyDiv(*node);
- // Simplify arithmetic operations with ones or zeros.
- if (use_shape_info &&
- (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
- properties->HasInputProperties(node->name()) &&
- properties->HasOutputProperties(node->name())) {
- const NodeDef* x = node_map_->GetNode(node->input(0));
- const NodeDef* y = node_map_->GetNode(node->input(1));
- if (x == nullptr || y == nullptr) {
- return errors::InvalidArgument("Invalid inputs to node: ",
- node->DebugString());
- }
- const TensorShapeProto& output_shape =
- properties->GetOutputProperties(node->name())[0].shape();
-
- // Simplify element-wise multiplication by ones or addition/subtraction
- // of zeros.
- const TensorShapeProto& y_shape =
- properties->GetInputProperties(node->name())[1].shape();
- const bool x_is_zero = IsZeros(*x);
- const bool x_is_one = x_is_zero ? false : IsOnes(*x);
- const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
- if (y_matches_output_shape &&
- ((is_mul && x_is_one) || (is_add && x_is_zero))) {
- // 1 * y = y or 0 + y = y.
- ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph);
- continue;
+ Tensor false_t(DT_BOOL, TensorShape({}));
+ Tensor true_t(DT_BOOL, TensorShape({}));
+ // Make sure we don't proceed if this switch node was already optimized.
+ if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
+ SetTensorValue(DT_BOOL, false, &false_t).ok()) {
+ // Copy the set of consumers of the switch as they will be manipulated
+ // below.
+ const std::set<NodeDef*>& consumer_set =
+ node_map_->GetOutputs(node->name());
+ std::vector<NodeDef*> consumers(consumer_set.begin(), consumer_set.end());
+ std::sort(consumers.begin(), consumers.end(),
+ [](const NodeDef* n1, const NodeDef* n2) {
+ return n1->name() < n2->name();
+ });
+ // Create constant false & true nodes.
+ NodeDef* false_node = optimized_graph->add_node();
+ false_node->set_name(OptimizedNodeName(*node, "_const_false"));
+ if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), false_node)
+ .ok()) {
+ return Status::OK();
}
+ false_node->set_device(node->device());
- if (y_matches_output_shape && (is_sub && x_is_zero)) {
- // Replace 0 - y with Neg(y).
- ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
- continue;
+ NodeDef* true_node = optimized_graph->add_node();
+ true_node->set_name(OptimizedNodeName(*node, "_const_true"));
+ if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
+ .ok()) {
+ return Status::OK();
}
-
- // Replace 1 / y with Reciprocal op.
- if (y_matches_output_shape && is_any_div && x_is_one) {
- DataType type = node->attr().at("T").type();
- if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
- ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
- continue;
+ true_node->set_device(node->device());
+
+ // Add controls from the switch ports to the constants, and connect the
+ // constants to the original switch outputs.
+ const string false_port = node->name();
+ const string true_port = strings::StrCat(node->name(), ":1");
+ const string false_ctrl_dep =
+ AddControlDependency(false_port, optimized_graph, node_map_.get());
+ false_node->add_input(false_ctrl_dep);
+ const string true_ctrl_dep =
+ AddControlDependency(true_port, optimized_graph, node_map_.get());
+ true_node->add_input(true_ctrl_dep);
+
+ node_map_->AddNode(false_node->name(), false_node);
+ node_map_->AddNode(true_node->name(), true_node);
+ node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
+ node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());
+
+ for (NodeDef* consumer : consumers) {
+ for (int i = 0; i < consumer->input_size(); ++i) {
+ const string& input = consumer->input(i);
+ if (input == false_port) {
+ consumer->set_input(i, false_node->name());
+ node_map_->UpdateInput(consumer->name(), false_port,
+ false_node->name());
+ } else if (input == true_port) {
+ consumer->set_input(i, true_node->name());
+ node_map_->UpdateInput(consumer->name(), true_port,
+ true_node->name());
+ }
}
}
+ graph_modified_ = true;
+ return Status::OK();
+ }
+ }
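
A standalone sketch (not TensorFlow code) of the observation in the comment above: when data and predicate are the same tensor x, the false output of Switch can only ever carry `false` and the true output can only ever carry `true`, which is why each branch can be fed a constant gated by a control dependency.

#include <cassert>

struct SwitchOut {
  bool value;        // the forwarded data value
  bool false_fires;  // whether the false-branch output is produced
  bool true_fires;   // whether the true-branch output is produced
};

SwitchOut Switch(bool data, bool pred) { return {data, !pred, pred}; }

int main() {
  for (bool x : {false, true}) {
    const SwitchOut out = Switch(x, x);
    if (out.false_fires) assert(out.value == false);
    if (out.true_fires) assert(out.value == true);
  }
}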
+ if (IsSimplifiableReduction(*node)) {
+ // Replace the reduction node with an identity node, that can be further
+ // optimized by the model pruner.
+ DataType output_type;
+ if (node->attr().count("T") > 0) {
+ output_type = node->attr().at("T").type();
+ } else {
+ // This is an 'any' or 'all' reduction. The output is always boolean.
+ output_type = DT_BOOL;
+ }
+ node->set_op("Identity");
+ node->clear_attr();
+ (*node->mutable_attr())["T"].set_type(output_type);
+ *node->mutable_input(1) = AsControlDependency(node->input(1));
+ graph_modified_ = true;
+ return Status::OK();
+ }
+ if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
+ DataType output_type = node->attr().at("T").type();
+ node->set_op("Identity");
+ node->clear_attr();
+ (*node->mutable_attr())["T"].set_type(output_type);
+ *node->mutable_input(1) = AsControlDependency(node->input(1));
+ graph_modified_ = true;
+ return Status::OK();
+ }
- const TensorShapeProto& x_shape =
- properties->GetInputProperties(node->name())[0].shape();
- const bool y_is_zero = IsZeros(*y);
- const bool y_is_one = y_is_zero ? false : IsOnes(*y);
- const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
- if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
- ((is_add || is_sub) && y_is_zero))) {
- // x * 1 = x or x / 1 = x or x +/- 0 = x
- ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph);
- continue;
- }
+ const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
+ const bool is_matmul = IsMatMul(*node);
+ const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
+ const bool is_sub = IsSub(*node);
+ const bool is_any_div = IsAnyDiv(*node);
+ // Simplify arithmetic operations with ones or zeros.
+ if (use_shape_info &&
+ (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
+ properties->HasInputProperties(node->name()) &&
+ properties->HasOutputProperties(node->name())) {
+ const NodeDef* x = node_map_->GetNode(node->input(0));
+ const NodeDef* y = node_map_->GetNode(node->input(1));
+ if (x == nullptr || y == nullptr) {
+ return errors::InvalidArgument("Invalid inputs to node: ",
+ node->DebugString());
+ }
+ const TensorShapeProto& output_shape =
+ properties->GetOutputProperties(node->name())[0].shape();
+
+ // Simplify element-wise multiplication by ones or addition/subtraction
+ // of zeros.
+ const TensorShapeProto& y_shape =
+ properties->GetInputProperties(node->name())[1].shape();
+ const bool x_is_zero = IsZeros(*x);
+ const bool x_is_one = x_is_zero ? false : IsOnes(*x);
+ const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
+ if (y_matches_output_shape &&
+ ((is_mul && x_is_one) || (is_add && x_is_zero))) {
+ // 1 * y = y or 0 + y = y.
+ ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph);
+ return Status::OK();
+ }
- // x OR true = true OR y = true.
- const PartialTensorShape shp(output_shape);
- if (shp.IsFullyDefined() && IsLogicalOr(*node) &&
- (y_is_one || x_is_one)) {
- TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
- 1, *properties, output_shape, node, optimized_graph));
- }
-
- // Simplify multiplication and matmul by zeros.
- // Also optimize zeros divided by a tensor, but only if we are in
- // aggressive mode, since we might get rid of divisions by zero.
- bool optimize_zeros_divided_by_y =
- is_any_div && x_is_zero && is_aggressive;
- if ((x_is_zero || y_is_zero) &&
- (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
- if (shp.IsFullyDefined()) {
- TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
- 0, *properties, output_shape, node, optimized_graph));
- continue;
- }
- // Even if an input shape is only partially known, we may known that it
- // matches the output shape and thus forward the corresponding zero
- // input.
- if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- continue;
- } else if (is_mul && y_is_zero && y_matches_output_shape) {
- ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
- continue;
- }
- }
+ if (y_matches_output_shape && (is_sub && x_is_zero)) {
+ // Replace 0 - y with Neg(y).
+ ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
+ return Status::OK();
}
- // Strength reduce floating point division by a constant Div(x, const) to
- // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
- // will be constant folded to Mul(x, 1.0/const).
- if (node->input_size() >= 2 && (IsRealDiv(*node) || IsDiv(*node))) {
- const string& const_input = node->input(1);
- const NodeDef* denom = node_map_->GetNode(const_input);
- CHECK(denom != nullptr);
- if (!IsReallyConstant(*denom)) {
- continue;
- }
- if (node->attr().count("T") == 0) {
- continue;
- }
+ // Replace 1 / y with Reciprocal op.
+ if (y_matches_output_shape && is_any_div && x_is_one) {
DataType type = node->attr().at("T").type();
- if (IsDiv(*node) &&
- !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
- continue;
+ if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
+ ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
+ return Status::OK();
}
- // Insert new reciprocal op and change node from Div to Mul.
- NodeDef* reciprocal_node = optimized_graph->add_node();
- reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
- reciprocal_node->set_op("Reciprocal");
- reciprocal_node->set_device(node->device());
- node->set_op("Mul");
- // Re-wire inputs and outputs.
- reciprocal_node->add_input(const_input);
- (*reciprocal_node->mutable_attr())["T"].set_type(type);
- node->set_input(1, reciprocal_node->name());
- node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
- node_map_->UpdateOutput(node->name(), const_input,
- reciprocal_node->name());
- graph_modified_ = true;
- continue;
}
- // Consider the transformation
- //
- // + + = parent
- // / \ / \
- // C + -- > X + = children
- // / \ / \
- // X Y C Y = leaves
- //
- // where C is constant and X is non-constant, and '+' denotes an
- // associative and commutative operator like addition or multiplication.
- // This optimization pushes constants down in the tree to canonicalize it.
- // Moreoever, in cases where the child node has a second constant input Y
- // we will create a leaf node that can be folded, e.g.
- //
- // Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
- //
- // TODO(rmlarsen): Handle non-associative/non-commutative operators like
- // subtraction and division, as well as mixed subtraction/addition,
- // division/multiplication.
- // Don't touch BiasAdd since they can't handle vectors as their first
- // inputs.
- if (has_fetch_ && (IsAdd(*node) || is_mul) &&
- NumNonControlInputs(*node) == 2) {
- NodeDef* left_child = node_map_->GetNode(node->input(0));
- NodeDef* right_child = node_map_->GetNode(node->input(1));
- // One child must be constant, and the other the same op as the parent.
- if (node->op() != left_child->op() && node->op() != right_child->op()) {
- continue;
- }
- const bool left_child_is_constant = IsReallyConstant(*left_child);
- const bool right_child_is_constant = IsReallyConstant(*right_child);
- if (!left_child_is_constant && !right_child_is_constant) {
- continue;
- }
- if (node->device() != left_child->device() ||
- node->device() != right_child->device()) {
- continue;
+ const TensorShapeProto& x_shape =
+ properties->GetInputProperties(node->name())[0].shape();
+ const bool y_is_zero = IsZeros(*y);
+ const bool y_is_one = y_is_zero ? false : IsOnes(*y);
+ const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
+ if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
+ ((is_add || is_sub) && y_is_zero))) {
+ // x * 1 = x or x / 1 = x or x +/- 0 = x
+ ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph);
+ return Status::OK();
+ }
+
+ // x OR true = true OR y = true.
+ const PartialTensorShape shp(output_shape);
+ if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
+ TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+ 1, *properties, output_shape, node, optimized_graph));
+ }
+
+ // Simplify multiplication and matmul by zeros.
+ // Also optimize zeros divided by a tensor, but only if we are in
+ // aggressive mode, since we might get rid of divisions by zero.
+ bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
+ if ((x_is_zero || y_is_zero) &&
+ (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
+ if (shp.IsFullyDefined()) {
+ TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+ 0, *properties, output_shape, node, optimized_graph));
+ return Status::OK();
}
- NodeDef* op_child_node =
- left_child_is_constant ? right_child : left_child;
- NodeDef* const_child_node =
- left_child_is_constant ? left_child : right_child;
- // Make sure that it is safe to change the value of the child node->
- if (op_child_node->input_size() < 2 ||
- nodes_to_preserve_.find(op_child_node->name()) !=
- nodes_to_preserve_.end() ||
- NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
- continue;
+ // Even if an input shape is only partially known, we may know that it
+ // matches the output shape and thus forward the corresponding zero
+ // input.
+ if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
+ ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ return Status::OK();
+ } else if (is_mul && y_is_zero && y_matches_output_shape) {
+ ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
+ return Status::OK();
}
+ }
+ }
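
A standalone sketch of the guard used throughout the block above: forwarding the surviving operand of 1*y, 0+y, x*1, x/1 or x+/-0 is only sound when that operand's shape already equals the output shape; otherwise the op was also performing a broadcast that forwarding would silently drop.

#include <cassert>
#include <cstdint>
#include <vector>

bool CanForwardOperand(const std::vector<int64_t>& operand_shape,
                       const std::vector<int64_t>& output_shape) {
  return operand_shape == output_shape;  // mirrors the ShapesEqual(...) checks
}

int main() {
  assert(CanForwardOperand({2, 3}, {2, 3}));   // same shape: forward the operand
  assert(!CanForwardOperand({1, 3}, {2, 3}));  // 1*y would broadcast y: keep the op
}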
- // Identify the nodes to swap.
- NodeDef* left_leaf = node_map_->GetNode(op_child_node->input(0));
- NodeDef* right_leaf = node_map_->GetNode(op_child_node->input(1));
- const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
- const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
- if (left_leaf_is_constant && right_leaf_is_constant) {
- // Child is already foldable, leave it alone.
- continue;
- }
- const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
- const int parent_const_input = left_child_is_constant ? 0 : 1;
- const auto& child_output = node_map_->GetOutputs(op_child_node->name());
- if (child_output.find(const_child_node) != child_output.end()) {
- // If there is a control edge from the child op to C, the transformation
- // would create a cycle in the graph. We know that it must be a control
- // edge. We can replace such a control edge with a control edge from A
- // to C.
- CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node,
- graph_, node_map_.get()));
- NodeDef* other_leaf = left_leaf_is_constant ? left_leaf : right_leaf;
- MaybeAddControlInput(other_leaf->name(), const_child_node, graph_,
- node_map_.get());
- }
-
- // Swap the constant child with a non-constant leaf node.
- node_map_->UpdateInput(node->name(), node->input(parent_const_input),
- op_child_node->input(non_const_leaf_input));
- node_map_->UpdateInput(op_child_node->name(),
- op_child_node->input(non_const_leaf_input),
- node->input(parent_const_input));
- std::swap(*node->mutable_input(parent_const_input),
- *op_child_node->mutable_input(non_const_leaf_input));
- graph_modified_ = true;
- continue;
+ // Strength reduce floating point division by a constant Div(x, const) to
+ // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
+ // will be constant folded to Mul(x, 1.0/const).
+ if (node->input_size() >= 2 && (IsRealDiv(*node) || IsDiv(*node))) {
+ const string& const_input = node->input(1);
+ const NodeDef* denom = node_map_->GetNode(const_input);
+ CHECK(denom != nullptr);
+ if (!IsReallyConstant(*denom)) {
+ return Status::OK();
+ }
+ if (node->attr().count("T") == 0) {
+ return Status::OK();
+ }
+ DataType type = node->attr().at("T").type();
+ if (IsDiv(*node) &&
+ !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
+ return Status::OK();
}
+ // Insert new reciprocal op and change node from Div to Mul.
+ NodeDef* reciprocal_node = optimized_graph->add_node();
+ reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
+ reciprocal_node->set_op("Reciprocal");
+ reciprocal_node->set_device(node->device());
+ node->set_op("Mul");
+ // Re-wire inputs and outputs.
+ reciprocal_node->add_input(const_input);
+ (*reciprocal_node->mutable_attr())["T"].set_type(type);
+ node->set_input(1, reciprocal_node->name());
+ node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
+ node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
+ graph_modified_ = true;
+ return Status::OK();
+ }
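
A small sketch of why the strength reduction above is restricted to floating-point and complex types: Div(x, c) with constant c equals Mul(x, 1/c) up to rounding for real arithmetic, but the same rewrite would be wrong for integer division, which truncates.

#include <cassert>
#include <cmath>

int main() {
  const double c = 4.0;          // the constant denominator
  const double recip = 1.0 / c;  // what the folded Reciprocal(c) would produce
  for (double x : {1.0, -3.5, 10.25}) {
    assert(std::fabs(x / c - x * recip) < 1e-12);
  }
  // Integer division truncates, so the rewrite changes results there.
  assert(7 / 2 == 3);
  assert(7 * (1 / 2) == 0);
}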
- // Partial constant propagation through IdentityN.
- if (IsIdentityN(*node) && NumNonControlInputs(*node) > 0) {
- const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
- const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
- bool updated_graph = false;
- for (int input_idx = 0; input_idx < node->input_size(); ++input_idx) {
- const string& input = node->input(input_idx);
- if (IsControlInput(input)) {
- break;
- }
- const NodeDef* input_node = node_map_->GetNode(NodeName(input));
- if (input_node == nullptr) {
- LOG(ERROR) << "Bad input: " << input;
- break;
- }
- // Forward constant inputs to outputs and add a control dependency on
- // the IdentityN node.
- if (IsReallyConstant(*input_node)) {
- // Update each consumer.
- for (NodeDef* consumer : consumers) {
- bool add_dep = false;
- for (int consumer_input_idx = 0;
- consumer_input_idx < consumer->input_size();
- ++consumer_input_idx) {
- const string& consumer_input =
- consumer->input(consumer_input_idx);
- if (IsControlInput(consumer_input)) {
- break;
- }
- int output_idx;
- const string input_node_name =
- ParseNodeName(consumer_input, &output_idx);
- if (input_node_name == node->name() && output_idx == input_idx) {
- consumer->set_input(consumer_input_idx, input);
- // We will keep the input from IdentityN through a control
- // dependency, so we only need to add the consumer as an output
- // for the constant input node.
- node_map_->AddOutput(NodeName(input), consumer->name());
- add_dep = true;
- }
+ if (ConstantPushDown(node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialConstPropThroughIdentityN(node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialConcatConstFolding(optimized_graph, properties, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+bool ConstantFolding::ConstantPushDown(NodeDef* node) {
+ // Consider the transformation
+ //
+ // + + = parent
+ // / \ / \
+ // C + -- > X + = children
+ // / \ / \
+ // X Y C Y = leaves
+ //
+ // where C is constant and X is non-constant, and '+' denotes an
+ // associative and commutative operator like addition or multiplication.
+ // This optimization pushes constants down in the tree to canonicalize it.
+ // Moreover, in cases where the child node has a second constant input Y
+ // we will create a leaf node that can be folded, e.g.
+ //
+ // Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
+ //
+ // TODO(rmlarsen): Handle non-associative/non-commutative operators like
+ // subtraction and division, as well as mixed subtraction/addition,
+ // division/multiplication.
+ // Don't touch BiasAdd since it can't handle vectors as its first
+ // input.
+ if (has_fetch_ && (IsAdd(*node) || IsMul(*node)) &&
+ NumNonControlInputs(*node) == 2) {
+ NodeDef* left_child = node_map_->GetNode(node->input(0));
+ NodeDef* right_child = node_map_->GetNode(node->input(1));
+ // One child must be constant, and the other the same op as the parent.
+ if (node->op() != left_child->op() && node->op() != right_child->op()) {
+ return false;
+ }
+ const bool left_child_is_constant = IsReallyConstant(*left_child);
+ const bool right_child_is_constant = IsReallyConstant(*right_child);
+ if (!left_child_is_constant && !right_child_is_constant) {
+ return false;
+ }
+ if (node->device() != left_child->device() ||
+ node->device() != right_child->device()) {
+ return false;
+ }
+ NodeDef* op_child_node = left_child_is_constant ? right_child : left_child;
+ NodeDef* const_child_node =
+ left_child_is_constant ? left_child : right_child;
+ // Make sure that it is safe to change the value of the child node.
+ if (op_child_node->input_size() < 2 ||
+ nodes_to_preserve_.find(op_child_node->name()) !=
+ nodes_to_preserve_.end() ||
+ NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
+ return false;
+ }
+
+ // Identify the nodes to swap.
+ NodeDef* left_leaf = node_map_->GetNode(op_child_node->input(0));
+ NodeDef* right_leaf = node_map_->GetNode(op_child_node->input(1));
+ const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
+ const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
+ if (left_leaf_is_constant && right_leaf_is_constant) {
+ // Child is already foldable, leave it alone.
+ return false;
+ }
+ const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
+ const int parent_const_input = left_child_is_constant ? 0 : 1;
+ const auto& child_output = node_map_->GetOutputs(op_child_node->name());
+ if (child_output.find(const_child_node) != child_output.end()) {
+      // If C is a consumer of the child op, the transformation would create a
+      // cycle in the graph. Since C is a constant, that edge must be a control
+      // edge, and we can replace it with a control edge from the other leaf
+      // to C.
+ CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node,
+ graph_, node_map_.get()));
+ NodeDef* other_leaf = left_leaf_is_constant ? left_leaf : right_leaf;
+ MaybeAddControlInput(other_leaf->name(), const_child_node, graph_,
+ node_map_.get());
+ }
+
+ // Swap the constant child with a non-constant leaf node.
+ node_map_->UpdateInput(node->name(), node->input(parent_const_input),
+ op_child_node->input(non_const_leaf_input));
+ node_map_->UpdateInput(op_child_node->name(),
+ op_child_node->input(non_const_leaf_input),
+ node->input(parent_const_input));
+ std::swap(*node->mutable_input(parent_const_input),
+ *op_child_node->mutable_input(non_const_leaf_input));
+ return true;
+ }
+ return false;
+}
+
+bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
+ // Partial constant propagation through IdentityN.
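+  // For example, if input i of IdentityN is a constant c, consumers that read
+  // output i of the IdentityN node are rewired to read c directly, and a
+  // control dependency on the IdentityN node is added to preserve ordering.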
+ if (IsIdentityN(*node) && NumNonControlInputs(*node) > 0) {
+ const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
+ const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
+ bool updated_graph = false;
+ for (int input_idx = 0; input_idx < node->input_size(); ++input_idx) {
+ const string& input = node->input(input_idx);
+ if (IsControlInput(input)) {
+ break;
+ }
+ const NodeDef* input_node = node_map_->GetNode(NodeName(input));
+ if (input_node == nullptr) {
+ LOG(ERROR) << "Bad input: " << input;
+ break;
+ }
+ // Forward constant inputs to outputs and add a control dependency on
+ // the IdentityN node.
+ if (IsReallyConstant(*input_node)) {
+ // Update each consumer.
+ for (NodeDef* consumer : consumers) {
+ bool add_dep = false;
+ for (int consumer_input_idx = 0;
+ consumer_input_idx < consumer->input_size();
+ ++consumer_input_idx) {
+ const string& consumer_input = consumer->input(consumer_input_idx);
+ if (IsControlInput(consumer_input)) {
+ break;
}
- if (add_dep) {
- consumer->add_input(AsControlDependency(node->name()));
- updated_graph = true;
+ int output_idx;
+ const string input_node_name =
+ ParseNodeName(consumer_input, &output_idx);
+ if (input_node_name == node->name() && output_idx == input_idx) {
+ consumer->set_input(consumer_input_idx, input);
+ // We will keep the input from IdentityN through a control
+ // dependency, so we only need to add the consumer as an output
+ // for the constant input node.
+ node_map_->AddOutput(NodeName(input), consumer->name());
+ add_dep = true;
}
}
+ if (add_dep) {
+ consumer->add_input(AsControlDependency(node->name()));
+ updated_graph = true;
+ }
}
}
+ }
- if (updated_graph) {
- for (NodeDef* consumer : consumers) {
- DedupControlInputs(consumer);
- }
- graph_modified_ = true;
- continue;
+ if (updated_graph) {
+ for (NodeDef* consumer : consumers) {
+ DedupControlInputs(consumer);
}
+ return true;
}
+ }
+ return false;
+}
- // Partial constant folding for associative operators:
- // Split AddN/AccumulateNV2 to enable partial
- // folding of ops when more than one but not all inputs are constant.
- // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
- // addition is commutative.
- const int num_non_control_inputs = NumNonControlInputs(*node);
- if (IsAggregate(*node) && IsCommutative(*node) &&
- num_non_control_inputs > 2) {
- const int num_control_inputs =
- node->input_size() - num_non_control_inputs;
- std::vector<int> const_inputs;
- std::vector<int> nonconst_inputs;
- for (int i = 0; i < node->input_size(); ++i) {
- const string& input = node->input(i);
- const NodeDef* input_node = node_map_->GetNode(NodeName(input));
- CHECK(input_node != nullptr) << input;
- if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
- const_inputs.push_back(i);
- } else {
- // Non-const and control inputs.
- nonconst_inputs.push_back(i);
- }
+bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
+ GraphProperties* properties,
+ NodeDef* node) {
+ // Partial constant folding for associative operators:
+ // Split AddN/AccumulateNV2 to enable partial
+ // folding of ops when more than one but not all inputs are constant.
+ // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
+ // addition is commutative.
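+  // For example, AddN(c1, x, c2, y) with constants c1 and c2 is rewritten as
+  // AddN(s, x, y), where s = AddN(c1, c2) is a new node that can subsequently
+  // be constant folded.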
+ const int num_non_control_inputs = NumNonControlInputs(*node);
+ if (IsAggregate(*node) && IsCommutative(*node) &&
+ num_non_control_inputs > 2) {
+ const int num_control_inputs = node->input_size() - num_non_control_inputs;
+ std::vector<int> const_inputs;
+ std::vector<int> nonconst_inputs;
+ for (int i = 0; i < node->input_size(); ++i) {
+ const string& input = node->input(i);
+ const NodeDef* input_node = node_map_->GetNode(NodeName(input));
+ CHECK(input_node != nullptr) << input;
+ if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
+ const_inputs.push_back(i);
+ } else {
+ // Non-const and control inputs.
+ nonconst_inputs.push_back(i);
}
- // Promote AccumulateNV2 with all constant inputs to AddN, since it is
- // a fake node that cannot be constant folded by itself.
- if (const_inputs.size() == num_non_control_inputs &&
- node->op() == "AccumulateNV2") {
- node->set_op("AddN");
- node->mutable_attr()->erase("shape");
- graph_modified_ = true;
- continue;
+ }
+ // Promote AccumulateNV2 with all constant inputs to AddN, since it is
+ // a fake node that cannot be constant folded by itself.
+ if (const_inputs.size() == num_non_control_inputs &&
+ node->op() == "AccumulateNV2") {
+ node->set_op("AddN");
+ node->mutable_attr()->erase("shape");
+ return true;
+ }
+ const string new_node_name = OptimizedNodeName(
+ *node, strings::StrCat("_partial_split_", const_inputs.size()));
+ if (1 < const_inputs.size() &&
+ const_inputs.size() < num_non_control_inputs &&
+ !node_map_->NodeExists(new_node_name)) {
+ NodeDef* added_node = optimized_graph->add_node();
+ *added_node = *node;
+      // Always use AddN for the constant node, since AccumulateNV2 is a fake
+      // node that cannot be constant folded because it does not have a kernel.
+ added_node->set_op("AddN");
+ added_node->mutable_attr()->erase("shape");
+ added_node->set_name(new_node_name);
+ node_map_->AddNode(added_node->name(), added_node);
+ added_node->clear_input();
+ for (int i : const_inputs) {
+ added_node->add_input(node->input(i));
+ node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
+ added_node->name());
}
- const string new_node_name = OptimizedNodeName(
- *node, strings::StrCat("_partial_split_", const_inputs.size()));
- if (1 < const_inputs.size() &&
- const_inputs.size() < num_non_control_inputs &&
- !node_map_->NodeExists(new_node_name)) {
- NodeDef* added_node = optimized_graph->add_node();
- *added_node = *node;
- // Always use AddN for the constant node, since AccumulateNV2 is a fake
- // node that cannot be constant folded, since it does not have a kernel.
- added_node->set_op("AddN");
- added_node->mutable_attr()->erase("shape");
- added_node->set_name(new_node_name);
- node_map_->AddNode(added_node->name(), added_node);
- added_node->clear_input();
- for (int i : const_inputs) {
- added_node->add_input(node->input(i));
- node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
- added_node->name());
- }
- // Overwrite the first const input with the added node.
- node->set_input(const_inputs[0], added_node->name());
- node_map_->AddOutput(added_node->name(), node->name());
- nonconst_inputs.push_back(const_inputs[0]);
- // Compact the remaining inputs to the original node.
- std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
- int idx = 0;
- for (int i : nonconst_inputs) {
- if (idx != i) {
- node->set_input(idx, node->input(i));
- }
- ++idx;
+ // Overwrite the first const input with the added node.
+ node->set_input(const_inputs[0], added_node->name());
+ node_map_->AddOutput(added_node->name(), node->name());
+ nonconst_inputs.push_back(const_inputs[0]);
+ // Compact the remaining inputs to the original node.
+ std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
+ int idx = 0;
+ for (int i : nonconst_inputs) {
+ if (idx != i) {
+ node->set_input(idx, node->input(i));
}
- node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
- const_inputs.size() - 1);
- (*node->mutable_attr())["N"].set_i(node->input_size() -
- num_control_inputs);
- properties->ClearInputProperties(node->name());
- (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
- graph_modified_ = true;
- continue;
- }
- }
-
- if (PartialConcatConstFolding(optimized_graph, properties, node)) {
- graph_modified_ = true;
- continue;
+ ++idx;
+ }
+ node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
+ const_inputs.size() - 1);
+ (*node->mutable_attr())["N"].set_i(node->input_size() -
+ num_control_inputs);
+ properties->ClearInputProperties(node->name());
+ (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
+ return true;
}
}
-
- return Status::OK();
+ return false;
}
bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
@@ -2545,6 +2570,17 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
cpu_device_ = owned_device_.get();
}
+ graph_contains_assign_or_inplace_op_ = false;
+ // TODO(rmlarsen): Enable in regular mode after May 15, 2018.
+ if (opt_level_ == RewriterConfig::AGGRESSIVE) {
+ for (const NodeDef& node : item.graph.node()) {
+ if (ModifiesInputsInPlace(node) || MaybeHasRefInput(node)) {
+ graph_contains_assign_or_inplace_op_ = true;
+ break;
+ }
+ }
+ }
+
has_fetch_ = !item.fetch.empty();
GrapplerItem item_to_optimize = item;
*optimized_graph = item.graph;
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 2096576538..227caba7ee 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -97,6 +97,8 @@ class ConstantFolding : public GraphOptimizer {
const GraphProperties& properties) const;
Status SimplifyGraph(GraphDef* output, GraphProperties* properties,
bool use_shape_info);
+ Status SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
+ GraphProperties* properties, bool use_shape_info);
Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
GraphDef* output);
@@ -106,6 +108,19 @@ class ConstantFolding : public GraphOptimizer {
bool PartialConcatConstFolding(GraphDef* optimized_graph,
GraphProperties* properties, NodeDef* node);
+ // Applies partial constant folding for associative operators AddN and
+ // AccumulateNV2. Returns true if the transformation applied successfully.
+ bool PartialAssocOpConstFolding(GraphDef* optimized_graph,
+ GraphProperties* properties, NodeDef* node);
+
+  // Applies partial constant propagation through the IdentityN operator.
+ // Returns true if the transformation applied successfully.
+ bool PartialConstPropThroughIdentityN(NodeDef* node);
+
+  // Pushes down constants on '+' and '*' operators if applicable. Returns
+  // true if the transformation applied successfully.
+ bool ConstantPushDown(NodeDef* node);
+
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;
@@ -119,6 +134,7 @@ class ConstantFolding : public GraphOptimizer {
std::unordered_set<string> feed_nodes_;
bool has_fetch_;
bool graph_modified_;
+ bool graph_contains_assign_or_inplace_op_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index f018b217e6..0bf51c48f7 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -33,77 +33,89 @@ class ConstantFoldingTest : public GrapplerTest {
protected:
template <DataType DTYPE>
void SimpleNeutralElementTest() {
- typedef typename EnumToDataType<DTYPE>::Type T;
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
- ops::Placeholder::Shape(TensorShape({2, 2})));
- Tensor zeros_t(DTYPE, TensorShape({2, 2}));
- Tensor ones_t(DTYPE, TensorShape({2, 2}));
- Tensor x_t(DTYPE, TensorShape({2, 2}));
- for (int i = 0; i < 4; ++i) {
- zeros_t.flat<T>()(i) = T(0);
- ones_t.flat<T>()(i) = T(1);
- x_t.flat<T>()(i) = T(i + 1);
- }
- Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
- Output ones = ops::Const(s.WithOpName("ones"), ones_t);
- Output mul1;
- Output mul2;
- Output add1;
- Output add2;
- if (DTYPE == DT_BOOL) {
- mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
- mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
- add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
- add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
- } else {
- mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
- mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
- add1 = ops::Add(s.WithOpName("add1"), x, zeros);
- add1 = ops::Add(s.WithOpName("add2"), x, ones);
- }
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch = {"mul1", "mul2", "add1", "add2"};
- ConstantFolding optimizer(nullptr /* cpu_device */);
- GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
-
- EXPECT_EQ(7, output.node_size());
- for (int i = 0; i < output.node_size(); ++i) {
- const NodeDef& node = output.node(i);
- const string& name = node.name();
- if (name == "mul1") {
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
- } else if (name == "mul2") {
- EXPECT_EQ("Snapshot", node.op());
- EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^ones", node.input(1));
- } else if (name == "add1") {
- EXPECT_EQ("Snapshot", node.op());
- EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("^zeros", node.input(1));
- } else if (name == "add2") {
- if (DTYPE == DT_BOOL) {
+ for (bool use_snapshot : {false, true}) {
+ typedef typename EnumToDataType<DTYPE>::Type T;
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output v = ops::Variable(s.WithOpName("v"), {2, 2}, DTYPE);
+ Tensor zeros_t(DTYPE, TensorShape({2, 2}));
+ Tensor ones_t(DTYPE, TensorShape({2, 2}));
+ Tensor x_t(DTYPE, TensorShape({2, 2}));
+ for (int i = 0; i < 4; ++i) {
+ zeros_t.flat<T>()(i) = T(0);
+ ones_t.flat<T>()(i) = T(1);
+ x_t.flat<T>()(i) = T(i + 1);
+ }
+ Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
+ Output ones = ops::Const(s.WithOpName("ones"), ones_t);
+ Output mul1;
+ Output mul2;
+ Output add1;
+ Output add2;
+ if (DTYPE == DT_BOOL) {
+ mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros);
+ mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones);
+ add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros);
+ add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones);
+ } else {
+ mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+ mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+ add1 = ops::Add(s.WithOpName("add1"), x, zeros);
+        add2 = ops::Add(s.WithOpName("add2"), x, ones);
+ }
+ if (use_snapshot) {
+ // Add an op with ref input to prevent Snapshot from being
+ // turned into Identity.
+ ops::Assign(s.WithOpName("assign"), v, ones);
+ }
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"mul1", "mul2", "add1", "add2"};
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(7, output.node_size());
+ const string snapshot_or_identity =
+ use_snapshot ? "Snapshot" : "Identity";
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ const string& name = node.name();
+ if (name == "mul1") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ("^x", node.input(0));
+ EXPECT_EQ("^zeros", node.input(1));
+ } else if (name == "mul2") {
+ EXPECT_EQ(snapshot_or_identity, node.op());
+ EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^ones", node.input(1));
- } else {
- EXPECT_EQ("Add", node.op());
+ } else if (name == "add1") {
+ EXPECT_EQ(snapshot_or_identity, node.op());
EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("ones", node.input(1));
+ EXPECT_EQ("^zeros", node.input(1));
+ } else if (name == "add2") {
+ if (DTYPE == DT_BOOL) {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^x", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ } else {
+ EXPECT_EQ("Add", node.op());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ones", node.input(1));
+ }
}
}
- }
- auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
- auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
- EXPECT_EQ(4, tensors_expected.size());
- EXPECT_EQ(4, tensors.size());
- for (int i = 0; i < item.fetch.size(); ++i) {
- test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
+ auto tensors_expected =
+ EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
+ auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
+ EXPECT_EQ(4, tensors_expected.size());
+ EXPECT_EQ(4, tensors.size());
+ for (int i = 0; i < item.fetch.size(); ++i) {
+ test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
+ }
}
}
};
@@ -284,7 +296,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch = {"stack", "matmul3", "matmul4"};
- ConstantFolding optimizer(nullptr /* cpu_device */);
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -309,11 +322,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (name == "mul3") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "mul4") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "mul5") {
@@ -325,7 +338,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ("^zeros_1d", node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (name == "div1") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ(ctrl_ones_name, node.input(1));
} else if (name == "div2") {
@@ -361,15 +374,15 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ(2, t.tensor_shape().dim(0).size());
EXPECT_EQ(3, t.tensor_shape().dim(1).size());
} else if (name == "add1") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "add2") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "bias_add1") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^zeros_1d", node.input(1));
} else if (name == "bias_add2") {
@@ -378,7 +391,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ(zeros_name, node.input(0));
EXPECT_EQ("bias", node.input(1));
} else if (name == "sub1") {
- EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "sub2") {
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index a44e1ee7f9..611d871eea 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -98,7 +98,7 @@ struct FunctionSpecializationSignature {
for (const auto& lhs : body_parameters) {
auto it = other.body_parameters.find(lhs.first);
if (it == other.body_parameters.end()) return false;
- if (!AreAttrValuesEqual(lhs.second, (*it).second)) return false;
+ if (!FastAreAttrValuesEqual(lhs.second, (*it).second)) return false;
}
return true;
@@ -123,7 +123,7 @@ struct FunctionSpecializationSignature {
s.body_parameters.end());
for (const auto& pair : body) {
h = Hash64Combine(Hash64(pair.first), h);
- h = Hash64Combine(AttrValueHash(pair.second), h);
+ h = Hash64Combine(FastAttrValueHash(pair.second), h);
}
std::map<int, string> inputs(s.const_inputs.begin(),
@@ -144,11 +144,18 @@ struct FunctionSpecialization {
std::unordered_set<string> control_deps;
};
+class FakeCPUDevice : public Device {
+ public:
+ FakeCPUDevice(Env* env, const DeviceAttributes& attr) : Device(env, attr) {}
+ Status Sync() override { return Status::OK(); }
+};
+
class FunctionOptimizerContext {
public:
explicit FunctionOptimizerContext(RewriterConfig::Toggle opt_level,
const GrapplerItem& item)
- : function_library_(OpRegistry::Global(), item.graph.library()) {
+ : graph_version_(item.graph.versions().producer()),
+ function_library_(OpRegistry::Global(), item.graph.library()) {
InitializeTrulyConstNodes(item);
InitializeInlinedFunctions(opt_level, item);
}
@@ -161,6 +168,11 @@ class FunctionOptimizerContext {
return &function_library_;
}
+ FunctionLibraryRuntime* mutable_function_library_runtime() {
+ InitializeFunctionLibraryRuntime();
+ return flr_;
+ }
+
bool IsInlinedFunction(const string& name) const {
return inlined_functions_.count(name) > 0;
}
@@ -222,12 +234,35 @@ class FunctionOptimizerContext {
}
}
+ void InitializeFunctionLibraryRuntime() {
+ if (!flr_) {
+ Env* env = Env::Default();
+ DeviceAttributes attr;
+ attr.set_name("/device:CPU:0");
+ attr.set_device_type("CPU");
+ Device* device = new FakeCPUDevice(env, attr);
+ device_mgr_.reset(new DeviceMgr({device}));
+ OptimizerOptions optimizer_opts;
+ optimizer_opts.set_do_function_inlining(true);
+ process_flr_.reset(new ProcessFunctionLibraryRuntime(
+ device_mgr_.get(), env, graph_version_, &function_library_,
+ optimizer_opts));
+ flr_ = process_flr_->GetFLR(device->name());
+ }
+ }
+
+ const int graph_version_;
FunctionLibraryDefinition function_library_;
+
+  // These fields are initialized lazily only if needed.
+ std::unique_ptr<DeviceMgr> device_mgr_;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> process_flr_;
+ FunctionLibraryRuntime* flr_ = nullptr;
+
// Functions that can be inlined into optimized graph.
std::unordered_map<string, const FunctionDef*> inlined_functions_;
// Nodes that are Const and not in feed.
std::unordered_map<string, const NodeDef*> truly_const_nodes_;
-
// Specialized functions.
std::unordered_map<FunctionSpecializationSignature,
const FunctionSpecialization,
@@ -497,63 +532,46 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
return Status::OK();
}
-// Copy input/output argument type to the type_list. Return error if argument
-// type is not explicitly defined, and not specified in function attributes.
-Status CopyArgType(const NodeDef& func_node,
- const std::unordered_map<string, AttrValue>& func_attr,
- const string& arg_kind, const OpDef::ArgDef& arg,
- AttrValue::ListValue* type_list) {
- if (arg.type() != DT_INVALID) {
- type_list->add_type(arg.type());
- } else {
- auto it = func_attr.find(arg.type_attr());
- if (it == func_attr.end() || it->second.type() == DT_INVALID) {
- return errors::InvalidArgument(
- "Invalid ", arg_kind, " argument ", arg.name(), " for function ",
- func_node.op(), " instantiated by ", func_node.name());
- }
- type_list->add_type(it->second.type());
- }
- return Status::OK();
-}
-
-// Add an IdentityN op to hook the function inputs to: this ensures that
+// Create an IdentityN node to hook the function inputs to: this ensures that
// they're all evaluated before the evaluation of the function body starts.
-Status HookInlinedFunctionInputs(
- const NodeDef& func_node, const FunctionDef& func,
- const std::unordered_map<string, AttrValue>& func_attr, NodeDef* inputs) {
- inputs->set_name(strings::StrCat(func_node.name(), "/", "inlined_inputs"));
- inputs->set_op("IdentityN");
- inputs->set_device(func_node.device());
- *inputs->mutable_input() = func_node.input();
+NodeDef InlinedFunctionInputsNode(const NodeDef& func_node,
+ const GrapplerFunctionItem& item) {
+ NodeDef inputs;
+ inputs.set_name(strings::StrCat(func_node.name(), "/", "inlined_inputs"));
+ inputs.set_op("IdentityN");
+ inputs.set_device(func_node.device());
+ *inputs.mutable_input() = func_node.input();
AttrValue::ListValue* type_list =
- (*inputs->mutable_attr())["T"].mutable_list();
- for (const OpDef::ArgDef& arg : func.signature().input_arg()) {
- TF_RETURN_IF_ERROR(
- CopyArgType(func_node, func_attr, "input", arg, type_list));
+ (*inputs.mutable_attr())["T"].mutable_list();
+
+ for (const InputArgExpansion& input_arg : item.inputs()) {
+ for (int i = 0; i < input_arg.placeholders.size(); ++i) {
+ type_list->add_type(input_arg.data_type);
+ }
}
- return Status::OK();
+
+ return inputs;
}
-// Add an IdentityN op to hook the function outputs to: this ensures that the
-// function body is fully evaluated before its fanout gets scheduled.
-Status HookInlinedFunctionOutputs(
- const NodeDef& func_node, const FunctionDef& func,
- const std::unordered_map<string, AttrValue>& func_attr,
- const gtl::ArraySlice<string> fetch, NodeDef* outputs) {
- outputs->set_name(func_node.name());
- outputs->set_op("IdentityN");
- outputs->set_device(func_node.device());
+// Create an IdentityN node to hook the function outputs to: this ensures that
+// the function body is fully evaluated before its fanout gets scheduled.
+NodeDef InlinedFunctionOutputsNode(const NodeDef& func_node,
+ const GrapplerFunctionItem& item) {
+ NodeDef outputs;
+ outputs.set_name(func_node.name());
+ outputs.set_op("IdentityN");
+ outputs.set_device(func_node.device());
AttrValue::ListValue* type_list =
- (*outputs->mutable_attr())["T"].mutable_list();
- for (int i = 0; i < func.signature().output_arg_size(); ++i) {
- const OpDef::ArgDef& arg = func.signature().output_arg(i);
- TF_RETURN_IF_ERROR(
- CopyArgType(func_node, func_attr, "output", arg, type_list));
- // Use the fetch names since they take into account the output mapping.
- outputs->add_input(strings::StrCat(func_node.name(), "/", fetch[i]));
+ (*outputs.mutable_attr())["T"].mutable_list();
+
+ for (const OutputArgExpansion& output_arg : item.outputs()) {
+ for (const string& output_tensor : output_arg.output_tensors) {
+ type_list->add_type(output_arg.data_type);
+ outputs.add_input(strings::StrCat(func_node.name(), "/", output_tensor));
+ }
}
- return Status::OK();
+
+ return outputs;
}
Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
@@ -574,27 +592,27 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
". Error: ", item_status.error_message());
}
- std::unordered_map<string, int> input_nodes;
- for (int i = 0; i < func.signature().input_arg_size(); ++i) {
- const OpDef::ArgDef& arg = func.signature().input_arg(i);
- input_nodes[arg.name()] = i;
+ // Mapping from input placeholder name to function input position.
+ int idx = 0;
+ std::unordered_map<string, int> input_placeholders_idx;
+ for (const InputArgExpansion& input_arg : item.inputs()) {
+ for (const string& placeholder : input_arg.placeholders) {
+ input_placeholders_idx[placeholder] = idx++;
+ }
}
- // Hook inlined function inputs to IdentityN node
+ // Hook inlined function inputs to IdentityN node.
NodeDef* func_inputs = optimized_graph->add_node();
- TF_RETURN_IF_ERROR(
- HookInlinedFunctionInputs(func_node, func, func_attr, func_inputs));
+ *func_inputs = InlinedFunctionInputsNode(func_node, item);
for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) {
- if (input_nodes.find(func_body_node.name()) != input_nodes.end()) {
+ if (item.IsInputPlaceholder(func_body_node.name())) {
+ // Turn input placeholders into identity nodes.
CHECK_EQ(0, func_body_node.input_size());
- // Turn input placeholders into identity nodes
- if (IsPlaceholder(func_body_node)) {
- func_body_node.set_op("Identity");
- }
- int input_id = input_nodes[func_body_node.name()];
+ func_body_node.set_op("Identity");
+ int input_idx = input_placeholders_idx[func_body_node.name()];
func_body_node.add_input(
- strings::StrCat(func_inputs->name(), ":", input_id));
+ strings::StrCat(func_inputs->name(), ":", input_idx));
} else {
// Update the input names if any.
for (string& input : *func_body_node.mutable_input()) {
@@ -608,18 +626,18 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
}
}
- // Add the node name as a prefix to avoid collisions after inlining
+ // Add the node name as a prefix to avoid collisions after inlining.
func_body_node.set_name(
strings::StrCat(func_node.name(), "/", func_body_node.name()));
- // Make sure the node is placed
+ // Make sure the node is placed.
func_body_node.set_device(func_node.device());
- // Check if a body node is itself a function
+ // Check if a body node is itself a function.
const FunctionDef* func_body_node_func =
ctx.FindInlinedFunction(func_body_node.op());
if (func_body_node_func != nullptr) {
- // Recursively inline function calls
+ // Recursively inline function calls.
TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func,
ctx, optimized_graph));
} else {
@@ -627,72 +645,20 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
for (const auto& attr : func.attr()) {
func_body_node.mutable_attr()->insert(attr);
}
- // Move the node to the main graph
+ // Move the node to the main graph.
optimized_graph->add_node()->Swap(&func_body_node);
}
}
- // Hook inlined function outputs to IdentityN node
+ // Hook inlined function outputs to IdentityN node.
NodeDef* func_outputs = optimized_graph->add_node();
- std::vector<string> fetch = OutputTensors(item);
- TF_RETURN_IF_ERROR(HookInlinedFunctionOutputs(func_node, func, func_attr,
- fetch, func_outputs));
+ *func_outputs = InlinedFunctionOutputsNode(func_node, item);
return Status::OK();
}
-class FakeCPUDevice : public Device {
- public:
- FakeCPUDevice(Env* env, const DeviceAttributes& attr) : Device(env, attr) {}
- Status Sync() override { return Status::OK(); }
-};
-
-class SymbolicGradientEnv {
- public:
- SymbolicGradientEnv(int graph_version, const FunctionDefLibrary& library)
- : graph_version_(graph_version), library_(library) {}
-
- FunctionLibraryDefinition* function_library() {
- InitializeIfNeeded();
- return fld_.get();
- }
- FunctionLibraryRuntime* function_library_runtime() {
- InitializeIfNeeded();
- return flr_;
- }
-
- private:
- // This initialization is expensive. Do it lazily to avoid paying for it
- // unless it's needed.
- void InitializeIfNeeded() {
- if (flr_) {
- return;
- }
- Env* env = Env::Default();
- DeviceAttributes attr;
- attr.set_name("/device:CPU:0");
- attr.set_device_type("CPU");
- FakeCPUDevice* dev = new FakeCPUDevice(env, attr);
- std::vector<Device*> devices;
- devices.push_back(dev);
- dvc_mgr_.reset(new DeviceMgr(devices));
- fld_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), library_));
- OptimizerOptions optimizer_opts;
- optimizer_opts.set_do_function_inlining(true);
- pflr_.reset(new ProcessFunctionLibraryRuntime(
- dvc_mgr_.get(), env, graph_version_, fld_.get(), optimizer_opts));
- flr_ = pflr_->GetFLR(dev->name());
- }
-
- const int graph_version_;
- const FunctionDefLibrary& library_;
- std::unique_ptr<DeviceMgr> dvc_mgr_;
- std::unique_ptr<FunctionLibraryDefinition> fld_;
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
- FunctionLibraryRuntime* flr_ = nullptr;
-};
-
-Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env,
+Status InlineSymbolicGradient(const NodeDef& node,
+ FunctionOptimizerContext* ctx,
GraphDef* inlined_graph) {
VLOG(2) << "Inline symbolic gradient: " << SummarizeNodeDef(node);
@@ -732,15 +698,15 @@ Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env,
GraphConstructorOptions graph_ctor_opts;
graph_ctor_opts.allow_internal_ops = true;
graph_ctor_opts.expect_device_spec = false;
- Graph graph(env->function_library());
+ Graph graph(ctx->function_library());
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(graph_ctor_opts, graph_def, &graph));
// Recursively inline the functions until there is nothing more to inline. We
// should at least expand one function.
int counter = 0;
- while (counter < 50 &&
- ExpandInlineFunctions(env->function_library_runtime(), &graph)) {
+ while (counter < 50 && ExpandInlineFunctions(
+ ctx->mutable_function_library_runtime(), &graph)) {
++counter;
}
@@ -801,8 +767,6 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
FunctionOptimizerContext ctx(opt_level_, item);
- SymbolicGradientEnv env(item.graph.versions().producer(),
- item.graph.library());
bool inline_gradients = options_.enable_symbolic_gradient_inlining;
bool inline_func = options_.enable_function_inlining;
@@ -816,7 +780,7 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
string f_name = f_attr != nullptr ? f_attr->func().name() : "";
if (ctx.IsInlinedFunction(f_name)) {
- TF_RETURN_IF_ERROR(InlineSymbolicGradient(node, &env, optimized_graph));
+ TF_RETURN_IF_ERROR(InlineSymbolicGradient(node, &ctx, optimized_graph));
continue;
}
}
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index 7d3520febc..490b337c3e 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -390,7 +390,7 @@ Status LoopInvariantNodeMotionOptimizer::Optimize() {
frame_children_[frame_ids[0]].insert(frame_ids[1]);
frame_parent_[frame_ids.back()] = frame_ids[frame_ids.size() - 2];
}
- if (frame_ids.size() >= 1) {
+ if (!frame_ids.empty()) {
frame_children_.insert(std::make_pair(frame_ids.back(), empty_set_));
if (node->op() == "LoopCond") {
if (loop_cond_.count(frame_ids.back())) {
@@ -409,7 +409,7 @@ Status LoopInvariantNodeMotionOptimizer::Optimize() {
}
for (auto it = frame_children_.begin(); it != frame_children_.end(); ++it) {
- if (it->second.size() == 0) {
+ if (it->second.empty()) {
worklist.push_back(it->first);
}
}
@@ -422,7 +422,7 @@ Status LoopInvariantNodeMotionOptimizer::Optimize() {
if (parent_it != frame_parent_.end()) {
int parent_id = parent_it->second;
frame_children_[parent_id].erase(frame_id);
- if (frame_children_[parent_id].size() == 0) {
+ if (frame_children_[parent_id].empty()) {
worklist.push_back(parent_id);
}
}
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 0c8e18d7ab..4435a8353b 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -24,11 +24,11 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
-#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
+#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
@@ -78,6 +78,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("pruning", new ModelPruner());
MK_OPT("function", new FunctionOptimizer(cfg_.function_optimization()));
MK_OPT("constfold", new ConstantFolding(cpu_device_));
+ MK_OPT("shape", new ShapeOptimizer());
MK_OPT("layout", new LayoutOptimizer());
MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
MK_OPT("arithmetic", new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
@@ -107,6 +108,9 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->emplace_back(
new ConstantFolding(cfg_.constant_folding(), cpu_device_));
}
+ if (cfg_.shape_optimization() == RewriterConfig::ON) {
+ optimizers->emplace_back(new ShapeOptimizer());
+ }
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
optimizers->emplace_back(
new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
@@ -344,6 +348,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.layout_optimizer() != RewriterConfig::OFF ||
cfg.function_optimization() != RewriterConfig::OFF ||
cfg.constant_folding() != RewriterConfig::OFF ||
+ cfg.shape_optimization() == RewriterConfig::ON ||
cfg.arithmetic_optimization() != RewriterConfig::OFF ||
cfg.loop_optimization() != RewriterConfig::OFF ||
cfg.dependency_optimization() != RewriterConfig::OFF ||
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
new file mode 100644
index 0000000000..26c54df56b
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
+
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+
+ GraphProperties properties(item);
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ GraphView graph(optimized_graph);
+
+ // The product of all the dimensions in a tensor shape can be expressed more
+ // simply as the size of the tensor.
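+  // For example, Prod(Shape(x), reduction_indices) with keep_dims=false and a
+  // single reduction index is rewritten as Size(x).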
+ for (auto& node : *optimized_graph->mutable_node()) {
+ if (!IsShape(node)) {
+ continue;
+ }
+ for (GraphView::InputPort fanout :
+ graph.GetFanout(GraphView::OutputPort(&node, 0))) {
+ if (fanout.node->op() != "Prod") {
+ continue;
+ }
+ if (fanout.node->attr().count("keep_dims") != 0 &&
+ fanout.node->attr().at("keep_dims").b()) {
+ // Keeping the reduced dimensions won't result in a scalar, so we can't
+ // rewrite the whole expression directly as a Size operation.
+ continue;
+ }
+ const GraphView::OutputPort reduce_indices =
+ graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1));
+ const auto& prop =
+ properties.GetOutputProperties(reduce_indices.node->name());
+      if (prop.size() <= reduce_indices.port_id) {
+ continue;
+ }
+ const TensorShapeProto& reduction_indices_shape =
+ prop[reduce_indices.port_id].shape();
+ if (NumCoefficients(reduction_indices_shape) == 1) {
+ const auto& input_props = properties.GetInputProperties(node.name());
+ if (input_props.size() != 1) {
+ continue;
+ }
+ // Rewrite the reduction of the shape dimensions as a Size operation.
+ const DataType type = input_props[0].dtype();
+ fanout.node->set_op("Size");
+ fanout.node->set_input(0, node.input(0));
+ fanout.node->set_input(1, AsControlDependency(node));
+ fanout.node->mutable_attr()->erase("Tidx");
+ fanout.node->mutable_attr()->erase("keep_dims");
+ (*fanout.node->mutable_attr())["out_type"] =
+ fanout.node->attr().at("T");
+ (*fanout.node->mutable_attr())["T"].set_type(type);
+ }
+ }
+ }
+ for (auto& node : *optimized_graph->mutable_node()) {
+ // Try to convert the ratio of 2 symbolic tensor sizes into a constant. This
+ // is possible whenever the symbolic dimensions in the numerator and
+ // denominator cancel each other.
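+    // For example, Div(Size(a), Size(b)) with a of shape [-2, 32] and b of
+    // shape [-2, 2] folds to the constant 16, since the unknown dimension
+    // appears in both sizes and cancels.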
+ if (node.op() == "Div") {
+ const GraphView::OutputPort input1 =
+ graph.GetRegularFanin(GraphView::InputPort(&node, 0));
+ const GraphView::OutputPort input2 =
+ graph.GetRegularFanin(GraphView::InputPort(&node, 1));
+ if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
+ continue;
+ }
+ const auto& prop1 = properties.GetInputProperties(input1.node->name());
+ const auto& prop2 = properties.GetInputProperties(input2.node->name());
+ if (prop1.size() != 1 || prop2.size() != 1) {
+ continue;
+ }
+ const TensorShapeProto& shape1 = prop1[0].shape();
+ const TensorShapeProto& shape2 = prop2[0].shape();
+ int64 result = ComputeSizeRatio(shape1, shape2);
+ if (result >= 0) {
+ // Replace div with constant.
+ node.set_op("Const");
+ DataType dtype = node.attr().at("T").type();
+ node.mutable_attr()->erase("T");
+ (*node.mutable_attr())["dtype"].set_type(dtype);
+ TensorProto* t = (*node.mutable_attr())["value"].mutable_tensor();
+ t->set_dtype(dtype);
+ *t->mutable_tensor_shape() = TensorShapeProto();
+ if (dtype == DT_INT32) {
+ t->add_int_val(result);
+ } else {
+ t->add_int64_val(result);
+ }
+ node.set_input(0, AsControlDependency(node.input(0)));
+ node.set_input(1, AsControlDependency(node.input(1)));
+ }
+ }
+ }
+ return Status::OK();
+}
+
+void ShapeOptimizer::Feedback(Cluster* /*cluster*/,
+ const GrapplerItem& /*item*/,
+ const GraphDef& /*optimized_graph*/,
+ double /*result*/) {
+  // Nothing to do for ShapeOptimizer.
+}
+
+} // end namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.h b/tensorflow/core/grappler/optimizers/shape_optimizer.h
new file mode 100644
index 0000000000..b7f84a1e5d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.h
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SHAPE_OPTIMIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SHAPE_OPTIMIZER_H_
+
+#include <unordered_set>
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/frame.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Optimize TensorFlow subgraphs that operate on shape and shape-related
+// information.
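+// Currently this rewrites Prod(Shape(x)) into Size(x) when the reduction
+// produces a scalar, and replaces Div(Size(a), Size(b)) with a constant when
+// the ratio of the two sizes is statically known.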
+class ShapeOptimizer : public GraphOptimizer {
+ public:
+ ShapeOptimizer() : opt_level_(RewriterConfig::ON) {}
+ explicit ShapeOptimizer(RewriterConfig::Toggle opt_level)
+ : opt_level_(opt_level) {}
+
+ ~ShapeOptimizer() override {}
+
+ string name() const override { return "shape_optimizer"; };
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override;
+
+ private:
+ RewriterConfig::Toggle opt_level_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SHAPE_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer_test.cc b/tensorflow/core/grappler/optimizers/shape_optimizer_test.cc
new file mode 100644
index 0000000000..95a5eccd4f
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer_test.cc
@@ -0,0 +1,105 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class ShapeOptimizerTest : public GrapplerTest {};
+
+TEST_F(ShapeOptimizerTest, OptimizeShapeProduct) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 3.14f, {32, 16});
+ Output c = ops::Shape(s.WithOpName("c"), a);
+ Output d = ops::Const(s.WithOpName("d"), 0, {1});
+ ops::ReduceProd::Attrs attrs;
+ Output e = ops::ReduceProd(s.WithOpName("e"), c, d, attrs.KeepDims(false));
+ Output f = ops::ReduceProd(s.WithOpName("f"), c, d, attrs.KeepDims(true));
+
+ GrapplerItem item;
+ item.fetch = {"e", "f"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ ShapeOptimizer optimizer;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "e") {
+ found++;
+ EXPECT_EQ("Size", node.op());
+ EXPECT_EQ("a", node.input(0));
+ } else if (node.name() == "f") {
+ found++;
+ EXPECT_EQ("Prod", node.op());
+ EXPECT_EQ("c", node.input(0));
+ }
+ }
+ EXPECT_EQ(2, found);
+
+ auto tensors_actual = EvaluateNodes(output, item.fetch);
+ EXPECT_NEAR(tensors_expected[0].scalar<int>()(),
+ tensors_actual[0].scalar<int>()(), 0);
+ EXPECT_NEAR(tensors_expected[1].scalar<int>()(),
+ tensors_actual[1].scalar<int>()(), 0);
+}
+
+TEST_F(ShapeOptimizerTest, OptimizeShapeRatio) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 3.14f, {32, 32});
+ Output b = ops::Const(s.WithOpName("b"), 3.14f, {32, 16});
+ Output c = ops::Size(s.WithOpName("c"), a);
+ Output d = ops::Size(s.WithOpName("d"), b);
+ Output e = ops::Div(s.WithOpName("e"), c, d);
+
+ GrapplerItem item;
+ item.fetch = {"e"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ ShapeOptimizer optimizer;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "e") {
+ found++;
+ EXPECT_EQ("Const", node.op());
+ }
+ }
+ EXPECT_EQ(1, found);
+
+ auto tensors_actual = EvaluateNodes(output, item.fetch);
+ EXPECT_NEAR(tensors_expected[0].scalar<int>()(),
+ tensors_actual[0].scalar<int>()(), 0);
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc b/tensorflow/core/grappler/optimizers/symbolic_shapes.cc
index cfca2dc0d3..32e86f8290 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc
+++ b/tensorflow/core/grappler/optimizers/symbolic_shapes.cc
@@ -49,6 +49,27 @@ bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties) {
return ShapeIsSymbolicallyDefined(properties.shape());
}
+int Rank(const TensorShapeProto& shape) {
+ if (shape.unknown_rank()) {
+ return -1;
+ }
+ return shape.dim_size();
+}
+
+int64 NumCoefficients(const TensorShapeProto& shape) {
+ if (shape.unknown_rank()) {
+ return -1;
+ }
+ int64 num_coefficients = 1;
+ for (const auto& dim : shape.dim()) {
+ if (dim.size() < 0) {
+ return -1;
+ }
+ num_coefficients *= dim.size();
+ }
+ return num_coefficients;
+}
+
bool ShapesSymbolicallyEqual(const TensorShapeProto& left,
const TensorShapeProto& right) {
if (left.unknown_rank() || right.unknown_rank() ||
@@ -173,5 +194,44 @@ bool CompareSymbolicallyShapedTensorSizes(
return CompareSymbolicallyShapedTensorSizes(left.shape(), right.shape());
}
+int64 ComputeSizeRatio(const TensorShapeProto& numerator,
+ const TensorShapeProto& denominator) {
+ if (numerator.unknown_rank() || denominator.unknown_rank()) {
+ return -1;
+ }
+ std::multiset<int> symbolic_dims;
+ int64 num = 1;
+ for (const auto& dim : numerator.dim()) {
+ if (dim.size() == -1) {
+ return -1;
+ } else if (dim.size() < -1) {
+ symbolic_dims.insert(dim.size());
+ } else {
+ num *= dim.size();
+ }
+ }
+ int64 denom = 1;
+ for (const auto& dim : denominator.dim()) {
+ if (dim.size() == -1) {
+ return -1;
+ } else if (dim.size() < -1) {
+ auto it = symbolic_dims.find(dim.size());
+ if (it == symbolic_dims.end()) {
+ return -1;
+ }
+ symbolic_dims.erase(it);
+ } else {
+ denom *= dim.size();
+ }
+ }
+ if (denom == 0) {
+ return -1;
+ }
+ if (!symbolic_dims.empty()) {
+ return -1;
+ }
+ return num / denom;
+}
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/optimizers/symbolic_shapes.h
index eb79bab314..38d7fbf090 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h
+++ b/tensorflow/core/grappler/optimizers/symbolic_shapes.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
@@ -31,6 +32,14 @@ bool IsUnknown(const TensorShapeProto::Dim& dim);
bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape);
bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties);
+// Returns the rank of the shape, or -1 if unknown.
+int Rank(const TensorShapeProto& shape);
+
+// Returns the number of coefficients in the shape or -1 if unknown.
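+// For example, a [2, 3, 4] shape has 24 coefficients; a shape with an unknown
+// or symbolic dimension yields -1.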
+// TODO(bsteiner): Add a function that computes the minimum size of the tensor,
+// i.e. the size assuming all the symbolic dimensions take the value 1.
+int64 NumCoefficients(const TensorShapeProto& shape);
+
// Shapes are symbolically equal, if they have the same rank, they are known or
// symbolically defined, and have matching dimensions.
bool ShapesSymbolicallyEqual(const TensorShapeProto& left,
@@ -54,6 +63,11 @@ bool CompareSymbolicallyShapedTensorSizes(
const OpInfo::TensorProperties& left,
const OpInfo::TensorProperties& right);
+// Returns the ratio of the sizes of the 2 shapes if known statically, or -1
+// otherwise.
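+// For example, the ratio of shapes [-2, 32] and [-2, 2] is 16, since the
+// symbolic dimension -2 appears in both shapes and cancels.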
+int64 ComputeSizeRatio(const TensorShapeProto& numerator,
+ const TensorShapeProto& denominator);
+
} // namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc b/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc
index 5ef9f65925..5720fbd097 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc
+++ b/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc
@@ -90,6 +90,33 @@ TEST_F(SymbolicShapesTest, CompareSymbolicallyShapedTensorSizes) {
EXPECT_FALSE(MakeShape({-1, -1, 32}) < MakeShape({1, -1, 32}));
}
+TEST_F(SymbolicShapesTest, RankAndNumCoeff) {
+ EXPECT_EQ(2, Rank(MakeShape({32, 32})));
+ EXPECT_EQ(32 * 32, NumCoefficients(MakeShape({32, 32})));
+ EXPECT_EQ(2, Rank(MakeShape({-2, 32})));
+ EXPECT_EQ(-1, NumCoefficients(MakeShape({-2, 32})));
+ TensorShapeProto shape;
+ shape.set_unknown_rank(true);
+ EXPECT_EQ(-1, Rank(shape));
+ EXPECT_EQ(-1, NumCoefficients(shape));
+}
+
+TEST_F(SymbolicShapesTest, SizeRatio) {
+ EXPECT_EQ(16, ComputeSizeRatio(MakeShape({32, 32}), MakeShape({32, 2})));
+ EXPECT_EQ(16, ComputeSizeRatio(MakeShape({-2, 32}), MakeShape({-2, 2})));
+ EXPECT_EQ(16,
+ ComputeSizeRatio(MakeShape({-2, -2, 32}), MakeShape({-2, 2, -2})));
+ EXPECT_EQ(-1,
+ ComputeSizeRatio(MakeShape({-2, -2, 32}), MakeShape({-2, 2, 2})));
+ EXPECT_EQ(-1,
+ ComputeSizeRatio(MakeShape({-2, 2, 32}), MakeShape({-2, 2, -2})));
+ EXPECT_EQ(-1, ComputeSizeRatio(MakeShape({-2, -2}), MakeShape({-2, 2})));
+ EXPECT_EQ(-1, ComputeSizeRatio(MakeShape({-2, 32}), MakeShape({-2, -2})));
+ EXPECT_EQ(1, ComputeSizeRatio(MakeShape({-2, -3}), MakeShape({-3, -2})));
+ EXPECT_EQ(-1, ComputeSizeRatio(MakeShape({-1, 32}), MakeShape({-2, 2})));
+ EXPECT_EQ(-1, ComputeSizeRatio(MakeShape({-1, 32}), MakeShape({-2, 0})));
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index 34603f9869..5a5dc47fa0 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -380,16 +380,6 @@ GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) {
return *this;
}
-std::vector<string> OutputTensors(const GrapplerFunctionItem& item) {
- std::vector<string> output_tensors;
- for (const OutputArgExpansion& output : item.outputs()) {
- for (const string& tensor : output.output_tensors) {
- output_tensors.push_back(tensor);
- }
- }
- return output_tensors;
-}
-
bool HasParametrizedType(const FunctionDef& func) {
const auto is_type_parametrized = [](const OpDef::ArgDef& arg) {
return !arg.type_attr().empty() || !arg.number_attr().empty() ||
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 4641bf5252..6227daa71b 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -176,9 +176,6 @@ class GrapplerFunctionItem : public GrapplerItem {
bool is_stateful_;
};
-// Return all output tensors referenced by item output args.
-std::vector<string> OutputTensors(const GrapplerFunctionItem& item);
-
// Check if function input/output types are fully defined only at instantiation
// time (parametrized by it's instantiation node).
bool HasParametrizedType(const FunctionDef& func);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 3fb03cd5bd..a0fe64113e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2032,9 +2032,12 @@ tf_kernel_library(
name = "functional_ops",
prefix = "functional_ops",
deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//third_party/eigen3",
],
)
@@ -4249,6 +4252,7 @@ cc_library(
":as_string_op",
":base64_ops",
":reduce_join_op",
+ ":regex_full_match_op",
":regex_replace_op",
":string_join_op",
":string_split_op",
@@ -4286,6 +4290,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "regex_full_match_op",
+ prefix = "regex_full_match_op",
+ deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
+)
+
+tf_kernel_library(
name = "regex_replace_op",
prefix = "regex_replace_op",
deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index a1c03f9918..475bda848d 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -329,6 +329,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
c_ptrs.push_back(&c_device_memory.back());
}
+ typedef Scalar Coefficient;
+
// Cublas does
// C = A x B
// where A, B and C are assumed to be in column major.
@@ -352,9 +354,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
bool blas_launch_status =
stream
->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
- static_cast<Scalar>(1.0), *(a_ptrs[0]),
+ static_cast<Coefficient>(1.0), *(a_ptrs[0]),
adj_x ? m : k, *(b_ptrs[0]), 1,
- static_cast<Scalar>(0.0), c_ptrs[0], 1)
+ static_cast<Coefficient>(0.0), c_ptrs[0], 1)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
@@ -366,9 +368,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
bool blas_launch_status =
stream
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
- static_cast<Scalar>(1.0), *(b_ptrs[0]),
+ static_cast<Coefficient>(1.0), *(b_ptrs[0]),
adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
- static_cast<Scalar>(0.0), c_ptrs[0], n)
+ static_cast<Coefficient>(0.0), c_ptrs[0], n)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
@@ -383,8 +385,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
stream
->ThenBlasGemmBatchedWithScratch(
blas_transpose_b, blas_transpose_a, n, m, k,
- static_cast<Scalar>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
- adj_x ? m : k, static_cast<Scalar>(0.0), c_ptrs, n,
+ static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+ adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
batch_size, &scratch_allocator)
.ok();
if (!blas_launch_status) {
@@ -398,6 +400,98 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
}
};
+template <>
+struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
+ static void Launch(OpKernelContext* context, const Tensor& in_x,
+ const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
+ typedef Eigen::half Scalar;
+ constexpr perftools::gputools::blas::Transpose kTranspose =
+ is_complex<Scalar>::value
+ ? perftools::gputools::blas::Transpose::kConjugateTranspose
+ : perftools::gputools::blas::Transpose::kTranspose;
+ perftools::gputools::blas::Transpose trans[] = {
+ perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
+ const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
+ const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
+ const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
+ const uint64 batch_size = in_x.dim_size(0);
+ auto blas_transpose_a = trans[adj_x];
+ auto blas_transpose_b = trans[adj_y];
+
+ auto* stream = context->op_device_context()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
+ std::vector<DeviceMemoryType> a_device_memory;
+ std::vector<DeviceMemoryType> b_device_memory;
+ std::vector<DeviceMemoryType> c_device_memory;
+ std::vector<DeviceMemoryType*> a_ptrs;
+ std::vector<DeviceMemoryType*> b_ptrs;
+ std::vector<DeviceMemoryType*> c_ptrs;
+ a_device_memory.reserve(batch_size);
+ b_device_memory.reserve(batch_size);
+ c_device_memory.reserve(batch_size);
+ a_ptrs.reserve(batch_size);
+ b_ptrs.reserve(batch_size);
+ c_ptrs.reserve(batch_size);
+ auto* a_base_ptr = in_x.template flat<Scalar>().data();
+ auto* b_base_ptr = in_y.template flat<Scalar>().data();
+ auto* c_base_ptr = out->template flat<Scalar>().data();
+ for (int64 i = 0; i < batch_size; ++i) {
+ a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
+ b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
+ c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
+ a_ptrs.push_back(&a_device_memory.back());
+ b_ptrs.push_back(&b_device_memory.back());
+ c_ptrs.push_back(&c_device_memory.back());
+ }
+
+ typedef float Coefficient;
+
+ // Cublas does
+ // C = A x B
+ // where A, B and C are assumed to be in column major.
+ // We want the output to be in row-major, so we can compute
+ // C' = B' x A', where ' stands for transpose (not adjoint).
+ // TODO(yangzihao): Choose the best of the three strategies using autotune.
+ if (batch_size == 1) {
+ // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
+ // overhead of the scratch allocator and the batch interface.
+ // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
+ static_cast<Coefficient>(1.0), *(b_ptrs[0]),
+ adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
+ static_cast<Coefficient>(0.0), c_ptrs[0], n)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(errors::Internal(
+ "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
+ ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+ ", k=", k));
+ }
+ } else {
+ CublasScratchAllocator scratch_allocator(context);
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemmBatchedWithScratch(
+ blas_transpose_b, blas_transpose_a, n, m, k,
+ static_cast<Coefficient>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+ adj_x ? m : k, static_cast<Coefficient>(0.0), c_ptrs, n,
+ batch_size, &scratch_allocator)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(
+ errors::Internal("Blas xGEMMBatched launch failed : a.shape=",
+ in_x.shape().DebugString(), ", b.shape=",
+ in_y.shape().DebugString(), ", m=", m, ", n=", n,
+ ", k=", k, ", batch_size=", batch_size));
+ }
+ }
+ }
+};
+
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
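
The new Eigen::half specialization keeps the same column-major trick the generic launcher documents: cuBLAS multiplies column-major operands, so the row-major product C = A x B is requested as C' = B' x A'. A self-contained check of that identity with plain loops (no cuBLAS involved; names are illustrative):

#include <array>
#include <cassert>

// Multiplies column-major matrices: out(m x n) = lhs(m x k) * rhs(k x n).
void GemmColMajor(const float* lhs, const float* rhs, float* out,
                  int m, int n, int k) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += lhs[p * m + i] * rhs[j * k + p];
      out[j * m + i] = acc;
    }
  }
}

int main() {
  // A is 2x3 and B is 3x2, both stored row-major.
  const std::array<float, 6> a = {1, 2, 3, 4, 5, 6};
  const std::array<float, 6> b = {7, 8, 9, 10, 11, 12};
  std::array<float, 4> c{};  // row-major 2x2 result
  // Reinterpreting a row-major buffer as column-major yields the transpose,
  // so computing B' * A' in column-major writes A * B in row-major into c.
  GemmColMajor(b.data(), a.data(), c.data(), /*m=*/2, /*n=*/2, /*k=*/3);
  assert(c[0] == 58 && c[1] == 64 && c[2] == 139 && c[3] == 154);
}
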
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 7e1e2aa4ec..2bb22bbd4f 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -15,6 +15,10 @@ limitations under the License.
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
+#if GOOGLE_CUDA
+#include "cuda/include/cuda.h"
+#endif // GOOGLE_CUDA
+
namespace tensorflow {
#if !defined(INTEL_MKL)
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index aca75176a5..bdd08222d4 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -404,10 +404,9 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
// image ('work_unit_size').
// TODO(andydavis)
- // *) Get L3 cache size from device at runtime (30MB is from ivybridge).
// *) Consider reducing 'target_working_set_size' if L3 is shared by
// other concurrently running tensorflow ops.
- const size_t target_working_set_size = (30LL << 20) / sizeof(T);
+ const size_t target_working_set_size = Eigen::l3CacheSize() / sizeof(T);
const size_t size_A = output_image_size * filter_total_size;
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 63a775afa8..95301b170f 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -420,9 +420,8 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
const int output_image_size =
dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;
- // TODO(andydavis) Get L2/L3 cache sizes from device.
- const size_t l2_cache_size = 256LL << 10;
- const size_t l3_cache_size = 30LL << 20;
+ const size_t l2_cache_size = Eigen::l2CacheSize();
+ const size_t l3_cache_size = Eigen::l3CacheSize();
// Use L3 cache size as target working set size.
const size_t target_working_set_size = l3_cache_size / sizeof(T);
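
The two convolution backprop kernels above stop hard-coding 256KB/30MB and instead size their working set from the cache sizes Eigen reports at runtime. A minimal sketch of that sizing rule, assuming the bundled Eigen exposes l2CacheSize()/l3CacheSize() as the patched headers here do:

#include <Eigen/Core>
#include <cstddef>
#include <cstdio>

template <typename T>
size_t TargetWorkingSetElements() {
  // l2CacheSize()/l3CacheSize() come from Eigen's blocking heuristics.
  // Use the L3 size when Eigen reports one, otherwise fall back to L2.
  const std::ptrdiff_t l3 = Eigen::l3CacheSize();
  const std::ptrdiff_t l2 = Eigen::l2CacheSize();
  const std::ptrdiff_t bytes = l3 > 0 ? l3 : l2;
  return static_cast<size_t>(bytes) / sizeof(T);
}

int main() {
  std::printf("float elements fitting in the cache budget: %zu\n",
              TargetWorkingSetElements<float>());
}
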
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 9edc6d416e..980b1063de 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -195,8 +195,8 @@ class Conv3DBackpropInputOp : public OpKernel {
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
- OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
- input_sizes.vec<int32>(), &input_shape));
+ // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
+ OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
} else {
input_shape = context->input(0).shape();
}
diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc
index 54ef9c6fb4..99d01b4db6 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op.cc
@@ -110,10 +110,10 @@ class CropAndResizeOp : public AsyncOpKernel {
public:
explicit CropAndResizeOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {
- string method;
- OP_REQUIRES_OK(context, context->GetAttr("method", &method));
- OP_REQUIRES(context, method == "bilinear",
- errors::InvalidArgument("method must be 'bilinear'", method));
+ OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
+ OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
+ errors::InvalidArgument(
+ "method must be 'bilinear' or 'nearest'", method_));
OP_REQUIRES_OK(context, context->GetAttr("extrapolation_value",
&extrapolation_value_));
}
@@ -178,7 +178,7 @@ class CropAndResizeOp : public AsyncOpKernel {
const Tensor& box_index = context->input(2);
const bool status = functor::CropAndResize<Device, T>()(
context, image.tensor<T, 4>(), boxes.tensor<float, 2>(),
- box_index.tensor<int32, 1>(), extrapolation_value_,
+ box_index.tensor<int32, 1>(), method_, extrapolation_value_,
output->tensor<float, 4>());
if (!status) {
context->SetStatus(
@@ -193,6 +193,7 @@ class CropAndResizeOp : public AsyncOpKernel {
private:
float extrapolation_value_;
+ string method_;
};
// Partial specialization of CropAndResize functor for a CPUDevice.
@@ -203,7 +204,7 @@ struct CropAndResize<CPUDevice, T> {
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_index,
- float extrapolation_value,
+ const string& method_name, float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
const int batch_size = image.dimension(0);
const int image_height = image.dimension(1);
@@ -247,37 +248,57 @@ struct CropAndResize<CPUDevice, T> {
}
continue;
}
- const int top_y_index = floorf(in_y);
- const int bottom_y_index = ceilf(in_y);
- const float y_lerp = in_y - top_y_index;
-
- for (int x = 0; x < crop_width; ++x) {
- const float in_x = (crop_width > 1)
- ? x1 * (image_width - 1) + x * width_scale
- : 0.5 * (x1 + x2) * (image_width - 1);
- if (in_x < 0 || in_x > image_width - 1) {
+ if (method_name == "bilinear") {
+ const int top_y_index = floorf(in_y);
+ const int bottom_y_index = ceilf(in_y);
+ const float y_lerp = in_y - top_y_index;
+
+ for (int x = 0; x < crop_width; ++x) {
+ const float in_x = (crop_width > 1)
+ ? x1 * (image_width - 1) + x * width_scale
+ : 0.5 * (x1 + x2) * (image_width - 1);
+ if (in_x < 0 || in_x > image_width - 1) {
+ for (int d = 0; d < depth; ++d) {
+ crops(b, y, x, d) = extrapolation_value;
+ }
+ continue;
+ }
+ const int left_x_index = floorf(in_x);
+ const int right_x_index = ceilf(in_x);
+ const float x_lerp = in_x - left_x_index;
+
for (int d = 0; d < depth; ++d) {
- crops(b, y, x, d) = extrapolation_value;
+ const float top_left(static_cast<float>(
+ image(b_in, top_y_index, left_x_index, d)));
+ const float top_right(static_cast<float>(
+ image(b_in, top_y_index, right_x_index, d)));
+ const float bottom_left(static_cast<float>(
+ image(b_in, bottom_y_index, left_x_index, d)));
+ const float bottom_right(static_cast<float>(
+ image(b_in, bottom_y_index, right_x_index, d)));
+ const float top = top_left + (top_right - top_left) * x_lerp;
+ const float bottom =
+ bottom_left + (bottom_right - bottom_left) * x_lerp;
+ crops(b, y, x, d) = top + (bottom - top) * y_lerp;
}
- continue;
}
- const int left_x_index = floorf(in_x);
- const int right_x_index = ceilf(in_x);
- const float x_lerp = in_x - left_x_index;
-
- for (int d = 0; d < depth; ++d) {
- const float top_left(static_cast<float>(
- image(b_in, top_y_index, left_x_index, d)));
- const float top_right(static_cast<float>(
- image(b_in, top_y_index, right_x_index, d)));
- const float bottom_left(static_cast<float>(
- image(b_in, bottom_y_index, left_x_index, d)));
- const float bottom_right(static_cast<float>(
- image(b_in, bottom_y_index, right_x_index, d)));
- const float top = top_left + (top_right - top_left) * x_lerp;
- const float bottom =
- bottom_left + (bottom_right - bottom_left) * x_lerp;
- crops(b, y, x, d) = top + (bottom - top) * y_lerp;
+ } else { // method_name == "nearest"
+ for (int x = 0; x < crop_width; ++x) {
+ const float in_x = (crop_width > 1)
+ ? x1 * (image_width - 1) + x * width_scale
+ : 0.5 * (x1 + x2) * (image_width - 1);
+ if (in_x < 0 || in_x > image_width - 1) {
+ for (int d = 0; d < depth; ++d) {
+ crops(b, y, x, d) = extrapolation_value;
+ }
+ continue;
+ }
+ const int closest_x_index = roundf(in_x);
+ const int closest_y_index = roundf(in_y);
+ for (int d = 0; d < depth; ++d) {
+ crops(b, y, x, d) = static_cast<float>(
+ image(b_in, closest_y_index, closest_x_index, d));
+ }
}
}
}
@@ -285,12 +306,17 @@ struct CropAndResize<CPUDevice, T> {
};
// A rough estimation of the cost for each cropped box.
- const double cost_per_pixel =
+ double cost_per_pixel =
depth * (Eigen::TensorOpCost::AddCost<float>() * 6 +
Eigen::TensorOpCost::MulCost<float>() * 3 +
Eigen::TensorOpCost::CastCost<T, float>() * 4) +
(Eigen::TensorOpCost::AddCost<float>() * 2 +
Eigen::TensorOpCost::AddCost<float>() * 3);
+ if (method_name == "nearest") {
+ cost_per_pixel = depth * Eigen::TensorOpCost::CastCost<T, float>() +
+ Eigen::TensorOpCost::AddCost<float>() * 4 +
+ Eigen::TensorOpCost::MulCost<float>() * 4;
+ }
const double cost_per_box = crop_height * crop_width * cost_per_pixel;
const DeviceBase::CpuWorkerThreads& worker_threads =
@@ -309,10 +335,10 @@ class CropAndResizeGradImageOp : public AsyncOpKernel {
public:
explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {
- string method;
- OP_REQUIRES_OK(context, context->GetAttr("method", &method));
- OP_REQUIRES(context, method == "bilinear",
- errors::InvalidArgument("method must be 'bilinear'", method));
+ OP_REQUIRES_OK(context, context->GetAttr("method", &method_));
+ OP_REQUIRES(context, method_ == "bilinear" || method_ == "nearest",
+ errors::InvalidArgument(
+ "method must be 'bilinear' or 'nearest'", method_));
}
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
@@ -372,14 +398,14 @@ class CropAndResizeGradImageOp : public AsyncOpKernel {
&output),
done);
- auto compute_callback = [context, output]() {
+ auto compute_callback = [this, context, output]() {
const Tensor& grads = context->input(0);
const Tensor& boxes = context->input(1);
const Tensor& box_index = context->input(2);
const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
context->eigen_device<Device>(), grads.tensor<float, 4>(),
boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
- output->tensor<T, 4>());
+ output->tensor<T, 4>(), method_);
if (!status) {
context->SetStatus(errors::Internal(
"Failed launch CropAndResizeBackpropImage kernel."));
@@ -390,6 +416,9 @@ class CropAndResizeGradImageOp : public AsyncOpKernel {
batch_size, std::move(compute_callback),
std::move(done));
}
+
+ private:
+ string method_;
};
// Partial specialization of CropAndResizeBackpropImage functor for a CPUDevice.
@@ -400,7 +429,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_index,
- typename TTypes<T, 4>::Tensor grads_image) {
+ typename TTypes<T, 4>::Tensor grads_image,
+ const string& method_name) {
const int batch_size = grads_image.dimension(0);
const int image_height = grads_image.dimension(1);
const int image_width = grads_image.dimension(2);
@@ -448,21 +478,30 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
if (in_x < 0 || in_x > image_width - 1) {
continue;
}
- const int left_x_index = floorf(in_x);
- const int right_x_index = ceilf(in_x);
- const float x_lerp = in_x - left_x_index;
+ if (method_name == "bilinear") {
+ const int left_x_index = floorf(in_x);
+ const int right_x_index = ceilf(in_x);
+ const float x_lerp = in_x - left_x_index;
- for (int d = 0; d < depth; ++d) {
- const float dtop = (1 - y_lerp) * grads(b, y, x, d);
- grads_image(b_in, top_y_index, left_x_index, d) +=
- static_cast<T>((1 - x_lerp) * dtop);
- grads_image(b_in, top_y_index, right_x_index, d) +=
- static_cast<T>(x_lerp * dtop);
- const float dbottom = y_lerp * grads(b, y, x, d);
- grads_image(b_in, bottom_y_index, left_x_index, d) +=
- static_cast<T>((1 - x_lerp) * dbottom);
- grads_image(b_in, bottom_y_index, right_x_index, d) +=
- static_cast<T>(x_lerp * dbottom);
+ for (int d = 0; d < depth; ++d) {
+ const float dtop = (1 - y_lerp) * grads(b, y, x, d);
+ grads_image(b_in, top_y_index, left_x_index, d) +=
+ static_cast<T>((1 - x_lerp) * dtop);
+ grads_image(b_in, top_y_index, right_x_index, d) +=
+ static_cast<T>(x_lerp * dtop);
+ const float dbottom = y_lerp * grads(b, y, x, d);
+ grads_image(b_in, bottom_y_index, left_x_index, d) +=
+ static_cast<T>((1 - x_lerp) * dbottom);
+ grads_image(b_in, bottom_y_index, right_x_index, d) +=
+ static_cast<T>(x_lerp * dbottom);
+ }
+ } else { // method_name == "nearest"
+ for (int d = 0; d < depth; ++d) {
+ int closest_x_index = roundf(in_x);
+ int closest_y_index = roundf(in_y);
+ grads_image(b_in, closest_y_index, closest_x_index, d) +=
+ static_cast<T>(grads(b, y, x, d));
+ }
}
}
}
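
The CPU functor now branches per method: bilinear blends the four neighbouring pixels with x/y interpolation weights, while nearest rounds the fractional sample position and copies a single pixel. A standalone sketch of both samplers on a single-channel image (illustrative helper types, not the functor above); on a 2x2 image of {1, 2, 3, 4} it reproduces the 2.5 and 4.0 values the new tests expect:

#include <cassert>
#include <cmath>
#include <vector>

struct Image {
  int height, width;
  std::vector<float> data;  // row-major height x width
  float at(int y, int x) const { return data[y * width + x]; }
};

float SampleBilinear(const Image& img, float in_y, float in_x) {
  const int top = static_cast<int>(std::floor(in_y));
  const int bottom = static_cast<int>(std::ceil(in_y));
  const int left = static_cast<int>(std::floor(in_x));
  const int right = static_cast<int>(std::ceil(in_x));
  const float y_lerp = in_y - top;
  const float x_lerp = in_x - left;
  const float top_row =
      img.at(top, left) + (img.at(top, right) - img.at(top, left)) * x_lerp;
  const float bottom_row =
      img.at(bottom, left) +
      (img.at(bottom, right) - img.at(bottom, left)) * x_lerp;
  return top_row + (bottom_row - top_row) * y_lerp;
}

float SampleNearest(const Image& img, float in_y, float in_x) {
  return img.at(static_cast<int>(std::lround(in_y)),
                static_cast<int>(std::lround(in_x)));
}

int main() {
  const Image img{2, 2, {1, 2, 3, 4}};
  // Sampling the centre of a 2x2 image: bilinear averages all four pixels,
  // nearest rounds (0.5, 0.5) up and returns the bottom-right value.
  assert(SampleBilinear(img, 0.5f, 0.5f) == 2.5f);
  assert(SampleNearest(img, 0.5f, 0.5f) == 4.0f);
}
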
diff --git a/tensorflow/core/kernels/crop_and_resize_op.h b/tensorflow/core/kernels/crop_and_resize_op.h
index b6b1dbd7b0..61dc3f941f 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.h
+++ b/tensorflow/core/kernels/crop_and_resize_op.h
@@ -31,7 +31,7 @@ struct CropAndResize {
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
- float extrapolation_value,
+ string method_name, float extrapolation_value,
typename TTypes<float, 4>::Tensor crops);
};
@@ -41,7 +41,8 @@ struct CropAndResizeBackpropImage {
bool operator()(const Device& d, typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
- typename TTypes<T, 4>::Tensor grads_image);
+ typename TTypes<T, 4>::Tensor grads_image,
+ const string& method_name);
};
template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
index d12787d524..8ab08fb93a 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
@@ -32,11 +32,16 @@ typedef Eigen::GpuDevice GPUDevice;
namespace {
+enum InterpolationMethod {
+ BILINEAR = 0,
+ NEAREST = 1,
+};
+
template <typename T>
__global__ void CropAndResizeKernel(
const int32 nthreads, const T* image_ptr, const float* boxes_ptr,
const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
- int image_width, int crop_height, int crop_width, int depth,
+ int image_width, int crop_height, int crop_width, int depth, int method_id,
float extrapolation_value, float* crops_ptr) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
@@ -80,37 +85,47 @@ __global__ void CropAndResizeKernel(
continue;
}
- const int top_y_index = floorf(in_y);
- const int bottom_y_index = ceilf(in_y);
- const float y_lerp = in_y - top_y_index;
-
- const int left_x_index = floorf(in_x);
- const int right_x_index = ceilf(in_x);
- const float x_lerp = in_x - left_x_index;
-
- const float top_left(static_cast<float>(
- image_ptr[((b_in * image_height + top_y_index) * image_width +
- left_x_index) *
- depth +
- d]));
- const float top_right(static_cast<float>(
- image_ptr[((b_in * image_height + top_y_index) * image_width +
- right_x_index) *
- depth +
- d]));
- const float bottom_left(static_cast<float>(
- image_ptr[((b_in * image_height + bottom_y_index) * image_width +
- left_x_index) *
- depth +
- d]));
- const float bottom_right(static_cast<float>(
- image_ptr[((b_in * image_height + bottom_y_index) * image_width +
- right_x_index) *
- depth +
- d]));
- const float top = top_left + (top_right - top_left) * x_lerp;
- const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
- crops_ptr[out_idx] = top + (bottom - top) * y_lerp;
+ if (method_id == BILINEAR) {
+ const int top_y_index = floorf(in_y);
+ const int bottom_y_index = ceilf(in_y);
+ const float y_lerp = in_y - top_y_index;
+
+ const int left_x_index = floorf(in_x);
+ const int right_x_index = ceilf(in_x);
+ const float x_lerp = in_x - left_x_index;
+
+ const float top_left(static_cast<float>(
+ image_ptr[((b_in * image_height + top_y_index) * image_width +
+ left_x_index) *
+ depth +
+ d]));
+ const float top_right(static_cast<float>(
+ image_ptr[((b_in * image_height + top_y_index) * image_width +
+ right_x_index) *
+ depth +
+ d]));
+ const float bottom_left(static_cast<float>(
+ image_ptr[((b_in * image_height + bottom_y_index) * image_width +
+ left_x_index) *
+ depth +
+ d]));
+ const float bottom_right(static_cast<float>(
+ image_ptr[((b_in * image_height + bottom_y_index) * image_width +
+ right_x_index) *
+ depth +
+ d]));
+ const float top = top_left + (top_right - top_left) * x_lerp;
+ const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
+ crops_ptr[out_idx] = top + (bottom - top) * y_lerp;
+ } else { // method_id == NEAREST
+ const int closest_x_index = roundf(in_x);
+ const int closest_y_index = roundf(in_y);
+ crops_ptr[out_idx] = static_cast<float>(
+ image_ptr[((b_in * image_height + closest_y_index) * image_width +
+ closest_x_index) *
+ depth +
+ d]);
+ }
}
}
@@ -119,7 +134,7 @@ __global__ void CropAndResizeBackpropImageKernel(
const int32 nthreads, const float* grads_ptr, const float* boxes_ptr,
const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
int image_width, int crop_height, int crop_width, int depth,
- T* grads_image_ptr) {
+ T* grads_image_ptr, int method_id) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = d + depth * (w + crop_width * (h + crop_height * b))
int idx = out_idx;
@@ -160,41 +175,52 @@ __global__ void CropAndResizeBackpropImageKernel(
continue;
}
- const int top_y_index = floorf(in_y);
- const int bottom_y_index = ceilf(in_y);
- const float y_lerp = in_y - top_y_index;
-
- const int left_x_index = floorf(in_x);
- const int right_x_index = ceilf(in_x);
- const float x_lerp = in_x - left_x_index;
-
- const float dtop = (1 - y_lerp) * grads_ptr[out_idx];
- CudaAtomicAdd(
- grads_image_ptr +
- ((b_in * image_height + top_y_index) * image_width + left_x_index) *
- depth +
- d,
- static_cast<T>((1 - x_lerp) * dtop));
- CudaAtomicAdd(grads_image_ptr +
- ((b_in * image_height + top_y_index) * image_width +
- right_x_index) *
- depth +
- d,
- static_cast<T>(x_lerp * dtop));
-
- const float dbottom = y_lerp * grads_ptr[out_idx];
- CudaAtomicAdd(grads_image_ptr +
- ((b_in * image_height + bottom_y_index) * image_width +
- left_x_index) *
- depth +
- d,
- static_cast<T>((1 - x_lerp) * dbottom));
- CudaAtomicAdd(grads_image_ptr +
- ((b_in * image_height + bottom_y_index) * image_width +
- right_x_index) *
- depth +
- d,
- static_cast<T>(x_lerp * dbottom));
+ if (method_id == BILINEAR) {
+ const int top_y_index = floorf(in_y);
+ const int bottom_y_index = ceilf(in_y);
+ const float y_lerp = in_y - top_y_index;
+
+ const int left_x_index = floorf(in_x);
+ const int right_x_index = ceilf(in_x);
+ const float x_lerp = in_x - left_x_index;
+
+ const float dtop = (1 - y_lerp) * grads_ptr[out_idx];
+ CudaAtomicAdd(grads_image_ptr +
+ ((b_in * image_height + top_y_index) * image_width +
+ left_x_index) *
+ depth +
+ d,
+ static_cast<T>((1 - x_lerp) * dtop));
+ CudaAtomicAdd(grads_image_ptr +
+ ((b_in * image_height + top_y_index) * image_width +
+ right_x_index) *
+ depth +
+ d,
+ static_cast<T>(x_lerp * dtop));
+
+ const float dbottom = y_lerp * grads_ptr[out_idx];
+ CudaAtomicAdd(grads_image_ptr +
+ ((b_in * image_height + bottom_y_index) * image_width +
+ left_x_index) *
+ depth +
+ d,
+ static_cast<T>((1 - x_lerp) * dbottom));
+ CudaAtomicAdd(grads_image_ptr +
+ ((b_in * image_height + bottom_y_index) * image_width +
+ right_x_index) *
+ depth +
+ d,
+ static_cast<T>(x_lerp * dbottom));
+ } else { // method_id == NEAREST
+ const int closest_x_index = roundf(in_x);
+ const int closest_y_index = roundf(in_y);
+ CudaAtomicAdd(grads_image_ptr +
+ ((b_in * image_height + closest_y_index) * image_width +
+ closest_x_index) *
+ depth +
+ d,
+ static_cast<T>(grads_ptr[out_idx]));
+ }
}
}
@@ -324,7 +350,7 @@ struct CropAndResize<GPUDevice, T> {
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
- float extrapolation_value,
+ string method_name, float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
const int batch = image.dimension(0);
const int image_height = image.dimension(1);
@@ -338,13 +364,19 @@ struct CropAndResize<GPUDevice, T> {
const int total_count = num_boxes * crop_height * crop_width * depth;
const GPUDevice& d = context->eigen_device<GPUDevice>();
+ InterpolationMethod method = BILINEAR;
+ if (method_name == "nearest") {
+ method = NEAREST;
+ }
+
if (total_count > 0) {
CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
CropAndResizeKernel<<<config.block_count, config.thread_per_block, 0,
d.stream()>>>(
config.virtual_thread_count, image.data(), boxes.data(),
box_ind.data(), num_boxes, batch, image_height, image_width,
- crop_height, crop_width, depth, extrapolation_value, crops.data());
+ crop_height, crop_width, depth, method, extrapolation_value,
+ crops.data());
}
return d.ok();
}
@@ -356,7 +388,8 @@ struct CropAndResizeBackpropImage<GPUDevice, T> {
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind,
- typename TTypes<T, 4>::Tensor grads_image) {
+ typename TTypes<T, 4>::Tensor grads_image,
+ const string& method_name) {
const int batch = grads_image.dimension(0);
const int image_height = grads_image.dimension(1);
const int image_width = grads_image.dimension(2);
@@ -377,6 +410,12 @@ struct CropAndResizeBackpropImage<GPUDevice, T> {
config.virtual_thread_count, grads_image.data());
}
+ // Configure the interpolation method.
+ InterpolationMethod method = BILINEAR;
+ if (method_name == "nearest") {
+ method = NEAREST;
+ }
+
// Accumulate.
total_count = num_boxes * crop_height * crop_width * depth;
if (total_count > 0) {
@@ -385,7 +424,7 @@ struct CropAndResizeBackpropImage<GPUDevice, T> {
config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, grads.data(), boxes.data(),
box_ind.data(), num_boxes, batch, image_height, image_width,
- crop_height, crop_width, depth, grads_image.data());
+ crop_height, crop_width, depth, grads_image.data(), method);
}
return d.ok();
}
diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc
index 709082e799..6921020d09 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_test.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc
@@ -34,13 +34,14 @@ namespace tensorflow {
class CropAndResizeOpTest : public OpsTestBase {
protected:
template <typename T>
- void MakeOp(float extrapolation_value) {
+ void MakeOp(float extrapolation_value, const string& method) {
TF_EXPECT_OK(NodeDefBuilder("crop_and_resize_op", "CropAndResize")
.Input(FakeInput(DataTypeToEnum<T>::value))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_INT32))
.Input(FakeInput(DT_INT32))
.Attr("extrapolation_value", extrapolation_value)
+ .Attr("method", method)
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
}
@@ -48,7 +49,7 @@ class CropAndResizeOpTest : public OpsTestBase {
#define REGISTER_TEST(T) \
TEST_F(CropAndResizeOpTest, TestCropAndResize##T) { \
- MakeOp<T>(0); \
+ MakeOp<T>(0, "bilinear"); \
AddInputFromArray<T>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); \
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1}); \
AddInputFromArray<int32>(TensorShape({1}), {0}); \
@@ -58,6 +59,19 @@ class CropAndResizeOpTest : public OpsTestBase {
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1})); \
test::FillValues<float>(&expected, {2.5}); \
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); \
+ } \
+ \
+ TEST_F(CropAndResizeOpTest, TestCropAndResize##T##nearest) { \
+ MakeOp<T>(0, "nearest"); \
+ AddInputFromArray<T>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); \
+ AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1}); \
+ AddInputFromArray<int32>(TensorShape({1}), {0}); \
+ AddInputFromArray<int32>(TensorShape({2}), {1, 1}); \
+ TF_ASSERT_OK(RunOpKernel()); \
+ \
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1})); \
+ test::FillValues<float>(&expected, {4.0}); \
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0)); \
}
REGISTER_TEST(float)
@@ -72,7 +86,7 @@ REGISTER_TEST(int64)
#undef REGISTER_TEST
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Uint8) {
- MakeOp<uint8>(0);
+ MakeOp<uint8>(0, "bilinear");
// Input:
// 1, 2
// 3, 4
@@ -87,8 +101,24 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Uint8) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
+TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Uint8NearestNeighbor) {
+ MakeOp<uint8>(0, "nearest");
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<uint8>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
+ AddInputFromArray<int32>(TensorShape({1}), {0});
+ AddInputFromArray<int32>(TensorShape({2}), {1, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1}));
+ test::FillValues<float>(&expected, {4.0});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Flipped) {
- MakeOp<float>(0);
+ MakeOp<float>(0, "bilinear");
// Input:
// 1, 2
// 3, 4
@@ -103,8 +133,24 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Flipped) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
+TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1FlippedNearestNeighbor) {
+ MakeOp<float>(0, "nearest");
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<float>(TensorShape({1, 4}), {1, 1, 0, 0});
+ AddInputFromArray<int32>(TensorShape({1}), {0});
+ AddInputFromArray<int32>(TensorShape({2}), {1, 1});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1}));
+ test::FillValues<float>(&expected, {4.0});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3) {
- MakeOp<float>(0);
+ MakeOp<float>(0, "bilinear");
// Input:
// 1, 2
// 3, 4
@@ -124,8 +170,29 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
+TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NearestNeighbor) {
+ MakeOp<float>(0, "nearest");
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
+ AddInputFromArray<int32>(TensorShape({1}), {0});
+ AddInputFromArray<int32>(TensorShape({2}), {3, 3});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3, 3, 1}));
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {1, 2, 2,
+ 3, 4, 4,
+ 3, 4, 4});
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Flipped) {
- MakeOp<float>(0);
+ MakeOp<float>(0, "bilinear");
// Input:
// 1, 2
// 3, 4
@@ -145,8 +212,54 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Flipped) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
+TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3FlippedNearestNeighbor) {
+ MakeOp<float>(0, "nearest");
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<float>(TensorShape({1, 4}), {1, 1, 0, 0});
+ AddInputFromArray<int32>(TensorShape({1}), {0});
+ AddInputFromArray<int32>(TensorShape({2}), {3, 3});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3, 3, 1}));
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {4, 4, 3,
+ 4, 4, 3,
+ 2, 2, 1});
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2) {
- MakeOp<float>(0);
+ MakeOp<float>(0, "bilinear");
+ // Input:
+ // 1, 2, 3
+ // 4, 5, 6
+ // 7, 8, 9
+ AddInputFromArray<float>(TensorShape({1, 3, 3, 1}),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ AddInputFromArray<float>(TensorShape({2, 4}), {0, 0, 1, 1, 0, 0, 0.5, 0.5});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 0});
+ AddInputFromArray<int32>(TensorShape({2}), {2, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 2, 1}));
+
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {1, 3,
+ 7, 9,
+ 1, 2,
+ 4, 5});
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2NearestNeighbor) {
+ MakeOp<float>(0, "nearest");
// Input:
// 1, 2, 3
// 4, 5, 6
@@ -171,7 +284,32 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2Flipped) {
- MakeOp<float>(0);
+ MakeOp<float>(0, "bilinear");
+ // Input:
+ // 1, 2, 3
+ // 4, 5, 6
+ // 7, 8, 9
+ AddInputFromArray<float>(TensorShape({1, 3, 3, 1}),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ AddInputFromArray<float>(TensorShape({2, 4}), {1, 1, 0, 0, 0.5, 0.5, 0, 0});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 0});
+ AddInputFromArray<int32>(TensorShape({2}), {2, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 2, 1}));
+
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {9, 7,
+ 3, 1,
+ 5, 4,
+ 2, 1});
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2FlippedNearestNeighbor) {
+ MakeOp<float>(0, "nearest");
// Input:
// 1, 2, 3
// 4, 5, 6
@@ -197,7 +335,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2Flipped) {
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) {
const float v = -1;
- MakeOp<float>(v);
+ MakeOp<float>(v, "bilinear");
// Input:
// 1, 2
// 3, 4
@@ -218,7 +356,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) {
}
TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) {
- MakeOp<float>(0);
+ MakeOp<float>(0, "bilinear");
// Input:
// 1, 2
// 3, 4
@@ -236,7 +374,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) {
}
TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
- MakeOp<float>(0);
+ MakeOp<float>(0, "bilinear");
AddInputFromArray<float>(TensorShape({2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({1}), {0});
@@ -248,7 +386,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
}
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
- MakeOp<float>(0);
+ MakeOp<float>(0, "bilinear");
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({2}), {0, 0});
@@ -261,7 +399,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
}
TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
- MakeOp<float>(0);
+ MakeOp<float>(0, "bilinear");
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<int32>(TensorShape({1}), {1});
@@ -274,7 +412,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
}
TEST_F(CropAndResizeOpTest, TestWithSharding) {
- MakeOp<float>(0);
+ MakeOp<float>(0, "bilinear");
// Generate a relatively large input (999x999) so that sharding happens.
const int kLength = 999; // Length of the input. Must use an odd number.
const int kHalf = (kLength + 1) / 2; // Half size for the cropped result.
diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc
index 25560b7c28..04959df38d 100644
--- a/tensorflow/core/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc
@@ -352,7 +352,7 @@ struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {};
template <typename T>
class CudnnRnnAllocatorInTemp : public ScratchAllocator {
public:
- ~CudnnRnnAllocatorInTemp() = default;
+ ~CudnnRnnAllocatorInTemp() override = default;
explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
: context_(context) {}
@@ -571,7 +571,7 @@ Status ExtractForwardInput(OpKernelContext* context,
: 1;
if ((*input_h)->dims() != 3) {
- return errors::InvalidArgument("RNN input must be a 3-D vector.");
+ return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
}
model_shapes->num_layers = (*input_h)->dim_size(0) / model_shapes->dir_count;
model_shapes->num_units = (*input_h)->dim_size(2);
@@ -1411,7 +1411,7 @@ class CudnnRNNForwardOpV2<GPUDevice, T>
CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
status = DoForward<T>(
- context, *rnn_desc.get(), model_types(), model_shapes, input, input_h,
+ context, *rnn_desc, model_types(), model_shapes, input, input_h,
input_c, params, is_training(), output, output_h, output_c,
&reserve_space_allocator, &workspace_allocator, &fwd_profile_result);
if (!status.ok()) {
@@ -1422,12 +1422,11 @@ class CudnnRNNForwardOpV2<GPUDevice, T>
// Get reserve space from the forward pass.
Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
status = DoBackward<T>(
- context, *rnn_desc.get(), model_types(), model_shapes, input,
- input_h, input_c, params, output, output_h, output_c,
- &output_backprop, &output_h_backprop, &output_c_backprop,
- &reserve_space, &input_backprop, &input_h_backprop,
- &input_c_backprop, &params_backprop, &workspace_allocator,
- &bak_profile_result);
+ context, *rnn_desc, model_types(), model_shapes, input, input_h,
+ input_c, params, output, output_h, output_c, &output_backprop,
+ &output_h_backprop, &output_c_backprop, &reserve_space,
+ &input_backprop, &input_h_backprop, &input_c_backprop,
+ &params_backprop, &workspace_allocator, &bak_profile_result);
if (!status.ok()) {
continue;
}
diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc
index d50e9c9cf9..634b3c280f 100644
--- a/tensorflow/core/kernels/data/sql_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc
@@ -70,17 +70,19 @@ class SqlDatasetOp : public DatasetOpKernel {
"The set of supported databases is: {'sqlite'}.",
driver_name.c_str())));
- *output = new Dataset(driver_name, data_source_name, query, output_types_,
- output_shapes_);
+ *output = new Dataset(ctx, driver_name, data_source_name, query,
+ output_types_, output_shapes_);
}
private:
- class Dataset : public DatasetBase {
+ class Dataset : public GraphDatasetBase {
public:
- Dataset(const string& driver_name, const string& data_source_name,
- const string& query, const DataTypeVector& output_types,
+ Dataset(OpKernelContext* ctx, const string& driver_name,
+ const string& data_source_name, const string& query,
+ const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : driver_name_(driver_name),
+ : GraphDatasetBase(ctx),
+ driver_name_(driver_name),
data_source_name_(data_source_name),
query_(query),
output_types_(output_types),
@@ -102,6 +104,21 @@ class SqlDatasetOp : public DatasetOpKernel {
string DebugString() override { return "SqlDatasetOp::Dataset"; }
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* driver_name_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(driver_name_, &driver_name_node));
+ Node* data_source_name_node;
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(data_source_name_, &data_source_name_node));
+ Node* query_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(query_, &query_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {driver_name_node, data_source_name_node, query_node}, output));
+ return Status::OK();
+ }
+
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@@ -121,22 +138,62 @@ class SqlDatasetOp : public DatasetOpKernel {
bool* end_of_sequence) override {
mutex_lock l(mu_);
if (!query_connection_initialized_) {
- query_connection_initialized_ = true;
- query_connection_ = sql::DriverManager::CreateQueryConnection(
- dataset()->driver_name_);
- Status s = query_connection_->Open(dataset()->data_source_name_,
- dataset()->query_,
- dataset()->output_types_);
- if (!s.ok()) {
- LOG(WARNING) << "Failed to connect to database: " << s;
- return s;
- }
+ TF_RETURN_IF_ERROR(InitializeQueryConnection());
}
+ next_calls_++;
return query_connection_->GetNext(ctx, out_tensors, end_of_sequence);
}
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (query_connection_initialized_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("next_calls"), next_calls_));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (reader->Contains(full_name("next_calls"))) {
+ TF_RETURN_IF_ERROR(InitializeQueryConnection());
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("next_calls"), &next_calls_));
+ int64 rem_next_calls = next_calls_;
+ std::vector<Tensor> out_tensors;
+ bool end_of_sequence = false;
+ while (rem_next_calls--) {
+ TF_RETURN_IF_ERROR(query_connection_->GetNext(ctx, &out_tensors,
+ &end_of_sequence));
+ out_tensors.clear();
+ }
+ } else {
+ query_connection_initialized_ = false;
+ }
+ return Status::OK();
+ }
+
private:
+ Status InitializeQueryConnection() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ query_connection_initialized_ = true;
+ query_connection_ =
+ sql::DriverManager::CreateQueryConnection(dataset()->driver_name_);
+ Status s = query_connection_->Open(dataset()->data_source_name_,
+ dataset()->query_,
+ dataset()->output_types_);
+ next_calls_ = 0;
+ if (!s.ok()) {
+ LOG(WARNING) << "Failed to connect to database: " << s;
+ return s;
+ }
+ return Status::OK();
+ }
+
mutex mu_;
+ // TODO(shivaniagrawal): explore ways to seek into a SQLite database.
+ int64 next_calls_ GUARDED_BY(mu_) = 0;
std::unique_ptr<sql::QueryConnection> query_connection_ GUARDED_BY(mu_);
bool query_connection_initialized_ GUARDED_BY(mu_) = false;
};
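
The new Save/Restore pair makes SqlDataset checkpointable without seeking inside SQLite: SaveInternal records only how many GetNext calls have been issued, and RestoreInternal reopens the connection and replays that many rows, discarding the results. A minimal sketch of the replay idea over a hypothetical forward-only Cursor (not the TF classes):

#include <cstdio>
#include <vector>

// A forward-only cursor over some rows; it cannot seek, only restart.
struct Cursor {
  std::vector<int> rows;
  size_t pos;
  bool GetNext(int* out) {
    if (pos >= rows.size()) return false;
    *out = rows[pos++];
    return true;
  }
};

// "Restores" a cursor to a saved position by replaying and discarding rows,
// mirroring how RestoreInternal above re-issues next_calls_ GetNext calls.
Cursor RestoreByReplay(const std::vector<int>& rows, size_t next_calls) {
  Cursor c{rows, 0};
  int discarded;
  for (size_t i = 0; i < next_calls && c.GetNext(&discarded); ++i) {
  }
  return c;
}

int main() {
  Cursor c = RestoreByReplay({10, 20, 30, 40}, /*next_calls=*/2);
  int v;
  while (c.GetNext(&v)) std::printf("%d\n", v);  // prints 30 then 40
}
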
diff --git a/tensorflow/core/kernels/deep_conv2d.cc b/tensorflow/core/kernels/deep_conv2d.cc
index 829155fb31..014684de64 100644
--- a/tensorflow/core/kernels/deep_conv2d.cc
+++ b/tensorflow/core/kernels/deep_conv2d.cc
@@ -393,9 +393,8 @@ struct TransformFilters {
// Calculate filter transform batch based on cache/filter sizes.
- // Cache budget (based on L2 cache size = 256KB).
- // TODO(andydavis) Read cache size from system.
- const int64 cache_size = (256LL << 10) / sizeof(T);
+ // Cache budget (based on L2 cache size).
+ const int64 cache_size = Eigen::l2CacheSize() / sizeof(T);
// Fixed cost.
const int64 filter_transform_matrix_size =
@@ -1017,9 +1016,8 @@ struct DeepConv2D<CPUDevice, T> {
const int64 filter_shard_size = filter_shards_row * filter_shards_col;
const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;
- // Cache budget (based on L2 cache size = 256KB).
- // TODO(andydavis) Read cache size from the system.
- const int64 cache_size = (256LL << 10) / sizeof(T);
+ // Cache budget (based on L2 cache size).
+ const int64 cache_size = Eigen::l2CacheSize() / sizeof(T);
// Fixed costs.
const int64 tile_transform_matrix_size =
diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
index 7afa21acb9..42a4832910 100644
--- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
@@ -1076,7 +1076,7 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
{1}, 0, filter_shape, &filter_backprop));
// If there is nothing to compute, return.
- if (filter_shape.num_elements() == 0) {
+ if (out_backprop.shape().num_elements() == 0) {
return;
}
diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc
index 3f644a61bf..42fbf95cd3 100644
--- a/tensorflow/core/kernels/dequantize_op.cc
+++ b/tensorflow/core/kernels/dequantize_op.cc
@@ -96,27 +96,17 @@ class DequantizeOp : public OpKernel {
output);
}
} else if (mode_ == QUANTIZE_MODE_SCALED) {
- // The quantization logic for mode SCALED matches that of
- // QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
- static constexpr int num_bits = sizeof(T) * 8;
- const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
- bool is_signed = std::is_signed<T>::value;
- // If it is signed, we try to keep 0.0 being 0 and drop one bucket. For
- // example, if it is 8 bits, we have the range [-127, 127]. So for input
- // range of [-x, x], the scale should be 254/(2*x).
- //
- // If it is unsigned and num_bits == 8, the range with 8 bits is [0, 255].
- // If the input range is [0, x], then the scale is x/255 instead of 254 as
- // in the case above.
- const int target_bits = is_signed ? (num_bits - 1) : num_bits;
- const float target_range =
- static_cast<float>((uint64_t{1} << target_bits) - 1);
- const float scale_factor = max_abs / target_range;
+ // TODO(pauldonnelly): Update QuantizeAndDequantizeV2 and
+ // QuantizeAndDequantizeV3 to match this SCALED mode again.
+ const float scale_factor =
+ std::numeric_limits<T>::min() == 0
+ ? (max_range / std::numeric_limits<T>::max())
+ : std::max(min_range / std::numeric_limits<T>::min(),
+ max_range / std::numeric_limits<T>::max());
float* out_ptr = output->flat<float>().data();
const T* in_ptr = input.flat<T>().data();
-
const int64 num_elements = input.NumElements();
- for (int i = 0; i < num_elements; ++i) {
+ for (int64 i = 0; i < num_elements; ++i) {
out_ptr[i] = static_cast<int>(in_ptr[i]) * scale_factor;
}
}
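
Under the new SCALED rule, unsigned types scale by max_range / T_max and signed types take the larger of min_range / T_min and max_range / T_max, which is why the tests in the next file now expect 255 -> 255.0 for quint8 with range [-512, 255] and -128 -> -2.0 for qint8 with range [-2, 1]. A small sketch of that selection using plain int8_t/uint8_t in place of qint8/quint8:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

template <typename T>
float ScaledDequantizeFactor(float min_range, float max_range) {
  // Unsigned: only max_range matters. Signed: take whichever endpoint
  // needs the larger scale so neither side of the range is clipped.
  return std::numeric_limits<T>::min() == 0
             ? max_range / std::numeric_limits<T>::max()
             : std::max(min_range / std::numeric_limits<T>::min(),
                        max_range / std::numeric_limits<T>::max());
}

int main() {
  // quint8-like: [-512, 255] with 8 unsigned bits -> 255 / 255 = 1.
  assert(ScaledDequantizeFactor<uint8_t>(-512.0f, 255.0f) == 1.0f);
  // qint8-like: [-2, 1] -> max(-2 / -128, 1 / 127) = 0.015625.
  assert(ScaledDequantizeFactor<int8_t>(-2.0f, 1.0f) == 0.015625f);
}
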
diff --git a/tensorflow/core/kernels/dequantize_op_test.cc b/tensorflow/core/kernels/dequantize_op_test.cc
index 9938eb61aa..63b18d7263 100644
--- a/tensorflow/core/kernels/dequantize_op_test.cc
+++ b/tensorflow/core/kernels/dequantize_op_test.cc
@@ -127,8 +127,8 @@ TEST_F(DequantizeOpTest, DequantizeMinCombinedQuint16) {
TEST_F(DequantizeOpTest, DequantizeScaledQuint8Zero) {
RunDequantizeScaledTest<quint8>(-255.0f, 127.0f, 0, 0.0);
}
-TEST_F(DequantizeOpTest, DequantizeScaledQuint8ScaleIdentity) {
- RunDequantizeScaledTest<quint8>(-255.0f, 127.0f, 127, 127.0);
+TEST_F(DequantizeOpTest, DequantizeScaledQuint8CheckIgnoresNegative) {
+ RunDequantizeScaledTest<quint8>(-512.0f, 255.0f, 255, 255.0);
}
TEST_F(DequantizeOpTest, DequantizeScaledQuint8ScaleDown) {
RunDequantizeScaledTest<quint8>(-1.0f, 2.0f, 255, 2.0);
@@ -144,7 +144,7 @@ TEST_F(DequantizeOpTest, DequantizeScaledQint8ScaleIdentity) {
RunDequantizeScaledTest<qint8>(-10.0f, 127.0f, -127, -127.0);
}
TEST_F(DequantizeOpTest, DequantizeScaledQint8ScaleDown) {
- RunDequantizeScaledTest<qint8>(-2.0f, 1.0f, -127, -2.0);
+ RunDequantizeScaledTest<qint8>(-2.0f, 1.0f, -128, -2.0);
}
TEST_F(DequantizeOpTest, DequantizeScaledQint8ScaleUp) {
RunDequantizeScaledTest<qint8>(-1.0f, 300.0f, 42, 99.212601);
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index 911aa3a78f..9ae04a1062 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -15,16 +15,16 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_base.h"
+#endif
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#if GOOGLE_CUDA
-#include "tensorflow/stream_executor/stream.h"
-#endif // GOOGLE_CUDA
-
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -39,6 +39,21 @@ Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
}
+template <typename To, typename From> // use like this: down_cast<T*>(foo);
+inline To down_cast(From* f) { // so we only accept pointers
+ static_assert(
+ (std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
+ "target type not derived from source type");
+
+ // We skip the assert and hence the dynamic_cast if RTTI is disabled.
+#if !defined(__GNUC__) || defined(__GXX_RTTI)
+ // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
+ assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
+#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
+
+ return static_cast<To>(f);
+}
+
// If "t" is a scalar of a supported type, returns t != 0 in "*v".
Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
if (t.size() != 1) {
@@ -279,8 +294,46 @@ class WhileOp : public AsyncOpKernel {
}
void StartBody() {
+ Status s;
+ if (rets_.size() != 1) {
+ s = errors::InvalidArgument(
+ "Expected a single scalar return value from WhileOp cond, got ",
+ rets_.size(), " tensors.");
+ return Finish(s);
+ }
+ Tensor cond_t;
+#if GOOGLE_CUDA
+ const DeviceBase::GpuDeviceInfo* gpu_device_info =
+ ctx_->device()->tensorflow_gpu_device_info();
+ const bool is_hostmem_dtype =
+ rets_[0].dtype() == DT_INT32 || rets_[0].dtype() == DT_INT64;
+ if (!is_hostmem_dtype && gpu_device_info &&
+ (opts_.rets_alloc_attrs.empty() ||
+ !opts_.rets_alloc_attrs[0].on_host())) {
+ // Copy the ret value to host if it's allocated on device.
+ Device* device = down_cast<Device*>(ctx_->device());
+ DeviceContext* device_ctx = ctx_->op_device_context();
+ cond_t = Tensor(rets_[0].dtype(), rets_[0].shape());
+ Notification done_copy;
+ device_ctx->CopyDeviceTensorToCPU(
+ &rets_[0], /*tensor_name=*/"", device, &cond_t,
+ [&done_copy, &s](const Status& status) {
+ s = status;
+ done_copy.Notify();
+ });
+ done_copy.WaitForNotification();
+ if (!s.ok()) {
+ return Finish(s);
+ }
+ } else {
+ cond_t = rets_[0];
+ }
+#else
+ cond_t = rets_[0];
+#endif
bool cond;
- Status s = ToBool(rets_, &cond);
+ s = ToBool({cond_t}, &cond);
+
if (!s.ok()) {
return Finish(s);
}
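
The down_cast helper added above is a checked static down-cast: with RTTI available it asserts that the dynamic type really is the target type, and in optimized builds it reduces to a plain static_cast. A standalone usage illustration with made-up types (not TF's Device hierarchy):

#include <cassert>
#include <type_traits>

struct Base { virtual ~Base() = default; };
struct Derived : Base { int payload = 42; };

template <typename To, typename From>
inline To down_cast(From* f) {
  static_assert(
      std::is_base_of<From, typename std::remove_pointer<To>::type>::value,
      "target type not derived from source type");
#if !defined(__GNUC__) || defined(__GXX_RTTI)
  assert(f == nullptr || dynamic_cast<To>(f) != nullptr);  // debug-only check
#endif
  return static_cast<To>(f);
}

int main() {
  Derived d;
  Base* b = &d;  // static type Base*, dynamic type Derived
  Derived* back = down_cast<Derived*>(b);
  assert(back->payload == 42);
}
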
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 903b898d0a..2b010f816d 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/non_max_suppression_op.h"
+#include <queue>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -56,20 +57,9 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
errors::InvalidArgument("scores has incompatible shape"));
}
-static inline void DecreasingArgSort(const std::vector<float>& values,
- std::vector<int>* indices) {
- indices->resize(values.size());
- for (int i = 0; i < values.size(); ++i) (*indices)[i] = i;
- std::sort(
- indices->begin(), indices->end(),
- [&values](const int i, const int j) { return values[i] > values[j]; });
-}
-
-// Return true if intersection-over-union overlap between boxes i and j
-// is greater than iou_threshold.
-static inline bool IOUGreaterThanThreshold(
- typename TTypes<float, 2>::ConstTensor boxes, int i, int j,
- float iou_threshold) {
+// Return the intersection-over-union overlap between boxes i and j.
+static inline float IOU(typename TTypes<float, 2>::ConstTensor boxes, int i,
+ int j) {
const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2));
const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3));
const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2));
@@ -88,13 +78,13 @@ static inline bool IOUGreaterThanThreshold(
const float intersection_area =
std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
std::max<float>(intersection_xmax - intersection_xmin, 0.0);
- const float iou = intersection_area / (area_i + area_j - intersection_area);
- return iou > iou_threshold;
+ return intersection_area / (area_i + area_j - intersection_area);
}
void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
const Tensor& scores, const Tensor& max_output_size,
- const float iou_threshold) {
+ const float iou_threshold,
+ const float score_threshold) {
OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1,
errors::InvalidArgument("iou_threshold must be in [0, 1]"));
@@ -109,37 +99,61 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
std::vector<float> scores_data(num_boxes);
std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
- std::vector<int> sorted_indices;
- DecreasingArgSort(scores_data, &sorted_indices);
+
+ // Data structure for selection candidate in NMS.
+ struct Candidate {
+ int box_index;
+ float score;
+ };
+
+ auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
+ return bs_i.score < bs_j.score;
+ };
+ std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
+ candidate_priority_queue(cmp);
+ for (int i = 0; i < scores_data.size(); ++i) {
+ if (scores_data[i] > score_threshold) {
+ candidate_priority_queue.emplace(Candidate({i, scores_data[i]}));
+ }
+ }
+
+ auto suppress_func = [iou_threshold](const float x) {
+ return x <= iou_threshold ? 1 : 0;
+ };
std::vector<int> selected;
- std::vector<int> selected_indices(output_size, 0);
- int num_selected = 0;
- for (int i = 0; i < num_boxes; ++i) {
- if (selected.size() >= output_size) break;
- bool should_select = true;
+ std::vector<float> selected_scores;
+ Candidate next_candidate;
+ float iou, original_score;
+
+ while (selected.size() < output_size && !candidate_priority_queue.empty()) {
+ next_candidate = candidate_priority_queue.top();
+ original_score = next_candidate.score;
+ candidate_priority_queue.pop();
+
// Overlapping boxes are likely to have similar scores,
- // therefore we iterate through the selected boxes backwards.
- for (int j = num_selected - 1; j >= 0; --j) {
- if (IOUGreaterThanThreshold(boxes_data, sorted_indices[i],
- sorted_indices[selected_indices[j]],
- iou_threshold)) {
- should_select = false;
- break;
- }
+ // therefore we iterate through the previously selected boxes backwards
+ // in order to see if `next_candidate` should be suppressed.
+ for (int j = selected.size() - 1; j >= 0; --j) {
+ iou = IOU(boxes_data, next_candidate.box_index, selected[j]);
+ if (iou == 0.0) continue;
+ next_candidate.score *= suppress_func(iou);
+ if (next_candidate.score <= score_threshold) break;
}
- if (should_select) {
- selected.push_back(sorted_indices[i]);
- selected_indices[num_selected++] = i;
+
+ if (original_score == next_candidate.score) {
+ selected.push_back(next_candidate.box_index);
+ selected_scores.push_back(next_candidate.score);
}
}
- // Allocate output tensor
- Tensor* output = nullptr;
+ // Allocate output tensors
+ Tensor* output_indices = nullptr;
TensorShape output_shape({static_cast<int>(selected.size())});
- OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
- TTypes<int, 1>::Tensor selected_indices_data = output->tensor<int, 1>();
- std::copy_n(selected.begin(), selected.size(), selected_indices_data.data());
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, output_shape, &output_indices));
+ TTypes<int, 1>::Tensor output_indices_data = output_indices->tensor<int, 1>();
+ std::copy_n(selected.begin(), selected.size(), output_indices_data.data());
}
} // namespace
@@ -164,8 +178,9 @@ class NonMaxSuppressionOp : public OpKernel {
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
max_output_size.shape().DebugString()));
+ const float score_threshold_val = 0.0;
DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
- iou_threshold_);
+ iou_threshold_, score_threshold_val);
}
private:
@@ -194,11 +209,48 @@ class NonMaxSuppressionV2Op : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
iou_threshold.shape().DebugString()));
+ const float iou_threshold_val = iou_threshold.scalar<float>()();
+ const float score_threshold_val = 0.0;
+ DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
+ iou_threshold_val, score_threshold_val);
+ }
+};
+
+template <typename Device>
+class NonMaxSuppressionV3Op : public OpKernel {
+ public:
+ explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // boxes: [num_boxes, 4]
+ const Tensor& boxes = context->input(0);
+ // scores: [num_boxes]
+ const Tensor& scores = context->input(1);
+ // max_output_size: scalar
+ const Tensor& max_output_size = context->input(2);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(max_output_size.shape()),
+ errors::InvalidArgument("max_output_size must be 0-D, got shape ",
+ max_output_size.shape().DebugString()));
+ // iou_threshold: scalar
+ const Tensor& iou_threshold = context->input(3);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
+ errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
+ iou_threshold.shape().DebugString()));
const float iou_threshold_val = iou_threshold.scalar<float>()();
+ // score_threshold: scalar
+ const Tensor& score_threshold = context->input(4);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(score_threshold.shape()),
+ errors::InvalidArgument("score_threshold must be 0-D, got shape ",
+ score_threshold.shape().DebugString()));
+ const float score_threshold_val = score_threshold.scalar<float>()();
+
DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
- iou_threshold_val);
+ iou_threshold_val, score_threshold_val);
}
};
@@ -208,4 +260,7 @@ REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
NonMaxSuppressionV2Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
+ NonMaxSuppressionV3Op<CPUDevice>);
+
} // namespace tensorflow
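Note on the kernel change above: the selection loop now pulls candidates off a max-score priority queue, only enqueues boxes scoring above `score_threshold`, and keeps a candidate only if its score survives `suppress_func` against every previously selected box; with the hard-threshold `suppress_func` this reduces to ordinary greedy NMS. A minimal standalone sketch of that behaviour, with illustrative names and no TensorFlow dependencies:

```c++
#include <algorithm>
#include <cstdio>
#include <vector>

// Illustrative stand-ins for the kernel's internals; not the TensorFlow API.
struct Box { float y1, x1, y2, x2; };

float IoU(const Box& a, const Box& b) {
  const float area_a = (a.y2 - a.y1) * (a.x2 - a.x1);
  const float area_b = (b.y2 - b.y1) * (b.x2 - b.x1);
  if (area_a <= 0 || area_b <= 0) return 0.0f;
  const float yy1 = std::max(a.y1, b.y1), xx1 = std::max(a.x1, b.x1);
  const float yy2 = std::min(a.y2, b.y2), xx2 = std::min(a.x2, b.x2);
  const float inter = std::max(yy2 - yy1, 0.0f) * std::max(xx2 - xx1, 0.0f);
  return inter / (area_a + area_b - inter);
}

// Greedy hard-NMS: boxes scoring <= score_threshold are never candidates,
// and a candidate is dropped when it overlaps an already selected box by
// more than iou_threshold.
std::vector<int> SelectIndices(const std::vector<Box>& boxes,
                               const std::vector<float>& scores,
                               int max_output_size, float iou_threshold,
                               float score_threshold) {
  std::vector<int> order(boxes.size());
  for (int i = 0; i < static_cast<int>(order.size()); ++i) order[i] = i;
  std::sort(order.begin(), order.end(),
            [&](int i, int j) { return scores[i] > scores[j]; });
  std::vector<int> selected;
  for (int idx : order) {
    if (static_cast<int>(selected.size()) >= max_output_size) break;
    if (scores[idx] <= score_threshold) break;  // scores are sorted.
    bool keep = true;
    for (int s : selected) {
      if (IoU(boxes[idx], boxes[s]) > iou_threshold) { keep = false; break; }
    }
    if (keep) selected.push_back(idx);
  }
  return selected;
}

int main() {
  std::vector<Box> boxes = {{0, 0, 1, 1}, {0, 0.1f, 1, 1.1f}, {0, 10, 1, 11}};
  std::vector<float> scores = {0.9f, 0.75f, 0.6f};
  for (int i : SelectIndices(boxes, scores, 3, 0.5f, 0.0f)) printf("%d\n", i);
  // Prints 0 and 2: box 1 overlaps box 0 with IoU ~0.82 and is suppressed.
  return 0;
}
```

With the hard suppression function the multiplicative score update either leaves a score untouched or zeroes it, which is why this sketch can use a plain boolean `keep` flag instead of tracking the modified score.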
diff --git a/tensorflow/core/kernels/non_max_suppression_op.h b/tensorflow/core/kernels/non_max_suppression_op.h
index d4349edf17..933b1af447 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.h
+++ b/tensorflow/core/kernels/non_max_suppression_op.h
@@ -27,7 +27,8 @@ template <typename Device, typename T>
struct NonMaxSuppression {
void operator()(const Device& d, typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<float, 1>::ConstTensor scores,
- float iou_threshold, int max_output_size,
+ float iou_threshold, float score_threshold,
+ int max_output_size,
typename TTypes<int, 1>::Tensor selected_indices);
};
diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc
index 9387fb13bc..c71aa23e01 100644
--- a/tensorflow/core/kernels/non_max_suppression_op_test.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc
@@ -340,4 +340,195 @@ TEST_F(NonMaxSuppressionV2OpTest, TestEmptyInput) {
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}
+//
+// NonMaxSuppressionV3Op Tests
+//
+
+class NonMaxSuppressionV3OpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV3")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ }
+};
+
+TEST_F(NonMaxSuppressionV3OpTest, TestSelectFromThreeClusters) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+ test::FillValues<int>(&expected, {3, 0, 5});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest,
+ TestSelectFromThreeClustersWithScoreThreshold) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {0.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.4f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({2}));
+ test::FillValues<int>(&expected, {3, 0});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest,
+ TestSelectFromThreeClustersFlippedCoordinates) {
+ MakeOp();
+ AddInputFromArray<float>(TensorShape({6, 4}),
+ {1, 1, 0, 0, 0, 0.1f, 1, 1.1f, 0, .9f, 1, -0.1f,
+ 0, 10, 1, 11, 1, 10.1f, 0, 11.1f, 1, 101, 0, 100});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+ test::FillValues<int>(&expected, {3, 0, 5});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestSelectAtMostTwoBoxesFromThreeClusters) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {2});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({2}));
+ test::FillValues<int>(&expected, {3, 0});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest,
+ TestSelectAtMostThirtyBoxesFromThreeClusters) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {30});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+ test::FillValues<int>(&expected, {3, 0, 5});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestSelectSingleBox) {
+ MakeOp();
+ AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
+ AddInputFromArray<float>(TensorShape({1}), {.9f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+ test::FillValues<int>(&expected, {0});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestSelectFromTenIdenticalBoxes) {
+ MakeOp();
+
+ int num_boxes = 10;
+ std::vector<float> corners(num_boxes * 4);
+ std::vector<float> scores(num_boxes);
+ for (int i = 0; i < num_boxes; ++i) {
+ corners[i * 4 + 0] = 0;
+ corners[i * 4 + 1] = 0;
+ corners[i * 4 + 2] = 1;
+ corners[i * 4 + 3] = 1;
+ scores[i] = .9;
+ }
+ AddInputFromArray<float>(TensorShape({num_boxes, 4}), corners);
+ AddInputFromArray<float>(TensorShape({num_boxes}), scores);
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+ test::FillValues<int>(&expected, {0});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestInconsistentBoxAndScoreShapes) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f});
+ AddInputFromArray<int>(TensorShape({}), {30});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ Status s = RunOpKernel();
+
+ ASSERT_FALSE(s.ok());
+ EXPECT_TRUE(
+ str_util::StrContains(s.ToString(), "scores has incompatible shape"))
+ << s;
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestInvalidIOUThreshold) {
+ MakeOp();
+ AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
+ AddInputFromArray<float>(TensorShape({1}), {.9f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {1.2f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ Status s = RunOpKernel();
+
+ ASSERT_FALSE(s.ok());
+ EXPECT_TRUE(
+ str_util::StrContains(s.ToString(), "iou_threshold must be in [0, 1]"))
+ << s;
+}
+
+TEST_F(NonMaxSuppressionV3OpTest, TestEmptyInput) {
+ MakeOp();
+ AddInputFromArray<float>(TensorShape({0, 4}), {});
+ AddInputFromArray<float>(TensorShape({0}), {});
+ AddInputFromArray<int>(TensorShape({}), {30});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({0}));
+ test::FillValues<int>(&expected, {});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/ops_testutil.cc b/tensorflow/core/kernels/ops_testutil.cc
index cd13d31bbc..7aa7d1a586 100644
--- a/tensorflow/core/kernels/ops_testutil.cc
+++ b/tensorflow/core/kernels/ops_testutil.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
void OpsTestBase::SetDevice(const DeviceType& device_type,
std::unique_ptr<Device> device) {
- CHECK(device_.get()) << "No device provided";
+ CHECK(device_) << "No device provided";
device_type_ = device_type;
device_ = std::move(device);
#ifdef GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/quantize_op.cc b/tensorflow/core/kernels/quantize_op.cc
index fc26813a08..857273e04e 100644
--- a/tensorflow/core/kernels/quantize_op.cc
+++ b/tensorflow/core/kernels/quantize_op.cc
@@ -131,6 +131,7 @@ class QuantizeV2Op : public OpKernel {
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
+ typename TTypes<T>::Vec o = output->template flat<T>();
if (mode_ == QUANTIZE_MODE_MIN_COMBINED) {
const float scale_factor =
(static_cast<double>(std::numeric_limits<T>::max()) -
@@ -147,7 +148,6 @@ class QuantizeV2Op : public OpKernel {
       // semantic of std::round, which implements "round-half-away-from-zero",
// e.g., -5.5 gets rounded to -6, -5.4 goes to -5, 5.4 goes to 5,
// and 5.5 goes to 6.
- typename TTypes<T>::Vec o = output->template flat<T>();
bool is_signed = std::is_signed<T>::value;
if (is_signed) {
// The slow path.
@@ -180,29 +180,20 @@ class QuantizeV2Op : public OpKernel {
output);
}
} else if (mode_ == QUANTIZE_MODE_SCALED) {
- // The quantization logic for mode SCALED matches that of
- // QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
- typename TTypes<T>::Vec o = output->template flat<T>();
- static constexpr int num_bits = sizeof(T) * 8;
- const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
- const bool is_signed = std::is_signed<T>::value;
- float target_range;
- if (is_signed) {
- max_range = max_abs;
- min_range = -max_abs;
- // If it is signed, we try to keep 0.0 being 0 and drop one bucket. For
- // example, if it is 8 bits, we have the range [-127, 127]. So for input
- // range of [-x, x], the scale should be 254/(2*x).
- target_range = static_cast<float>((uint64_t{1} << (num_bits - 1)) - 1);
- } else {
- max_range = max_abs;
- min_range = 0.0;
- // If it is unsigned and num_bits == 8, the range with 8 bits is [0,
- // 255]. If the input range is [0, x], then the scale is x/255 instead
- // of 254 as in the case above.
- target_range = static_cast<float>((uint64_t{1} << num_bits) - 1);
- }
- const float scale_factor = target_range / max_abs;
+ const int min_output_value = std::numeric_limits<T>::min();
+ const int max_output_value = std::numeric_limits<T>::max();
+ const float scale_factor_from_min_side =
+ (min_output_value * min_range > 0)
+ ? min_output_value / min_range
+ : std::numeric_limits<float>::max();
+ const float scale_factor_from_max_side =
+ (max_output_value * max_range > 0)
+ ? max_output_value / max_range
+ : std::numeric_limits<float>::max();
+ const float scale_factor =
+ std::min(scale_factor_from_min_side, scale_factor_from_max_side);
+ min_range = min_output_value / scale_factor;
+ max_range = max_output_value / scale_factor;
if (round_mode_ == ROUND_HALF_TO_EVEN) {
// scalar_round_op_google implements "round-half-to-even".
o.device(ctx->template eigen_device<Device>()) =
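The rewritten SCALED branch no longer symmetrizes the range around zero up front; it picks the largest scale that keeps both endpoints representable in the output type and then reports the range implied by that scale. A standalone worked example matching the updated qint8 test expectations below, assuming T behaves like int8_t:

```c++
#include <algorithm>
#include <cstdio>
#include <cstdint>
#include <limits>

// Standalone sketch of the new SCALED range handling for a signed 8-bit
// target type; qint8 behaves like int8_t for this purpose.
int main() {
  const int min_out = std::numeric_limits<int8_t>::min();  // -128
  const int max_out = std::numeric_limits<int8_t>::max();  //  127
  float min_range = -128.0f, max_range = 100.0f;

  // A side only constrains the scale if it maps to a nonzero output value.
  const float from_min = (min_out * min_range > 0)
                             ? min_out / min_range
                             : std::numeric_limits<float>::max();
  const float from_max = (max_out * max_range > 0)
                             ? max_out / max_range
                             : std::numeric_limits<float>::max();
  const float scale = std::min(from_min, from_max);  // min(1.0, 1.27) = 1.0

  // The reported range is adjusted to exactly cover the output type.
  min_range = min_out / scale;  // -128.0
  max_range = max_out / scale;  //  127.0
  printf("scale=%g range=[%g, %g]\n", scale, min_range, max_range);
  return 0;
}
```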
diff --git a/tensorflow/core/kernels/quantize_op_test.cc b/tensorflow/core/kernels/quantize_op_test.cc
index 57982bdf76..0a672686a2 100644
--- a/tensorflow/core/kernels/quantize_op_test.cc
+++ b/tensorflow/core/kernels/quantize_op_test.cc
@@ -61,17 +61,17 @@ TEST_F(QuantizedOpTest, QuantizeV2Quint8Scaled) {
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({8}),
- {-255.0, 0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0});
+ {-255.0, 0.0, 1.0, 1.25, 1.75, 64.0, 127.0, 500.0});
AddInputFromArray<float>(TensorShape({1}), {-255.0f});
AddInputFromArray<float>(TensorShape({1}), {127.0f});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
- // Input element -5.0 should map to 0 even though min_range = -255, because
+ // Input values < 0 should map to 0 even though min_range = -255, because
// we are performing quantization by scaling to quint8.
- // Input element 0.0 should map to 0.
- // Input element 500.0 is quantized to 127 because
- // max(abs(-255), abs(127)) = 255.
- test::FillValues<quint8>(&expected, {0, 0, 1, 1, 2, 127, 255, 255});
+ // Input value 0.0 should map to 0.
+ // The scale factor chosen should be 255 / 127 = 2.00787
+ // Output values are clipped to 255.
+ test::FillValues<quint8>(&expected, {0, 0, 2, 3, 4, 129, 255, 255});
test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
Tensor expected_output_min(allocator(), DT_FLOAT, TensorShape({}));
@@ -79,7 +79,7 @@ TEST_F(QuantizedOpTest, QuantizeV2Quint8Scaled) {
test::ExpectTensorEqual<float>(expected_output_min, *GetOutput(1));
Tensor expected_output_max(allocator(), DT_FLOAT, TensorShape({}));
- test::FillValues<float>(&expected_output_max, {255.0});
+ test::FillValues<float>(&expected_output_max, {127.0});
test::ExpectTensorEqual<float>(expected_output_max, *GetOutput(2));
}
@@ -123,19 +123,19 @@ TEST_F(QuantizedOpTest, QuantizeV2Qint8Scaled) {
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({7}),
- {-127.0, 0.0, 1.0, 1.25, 1.75, 64.0, 127.0});
- AddInputFromArray<float>(TensorShape({1}), {-127.0f});
+ {-128.0, 0.0, 1.0, 1.25, 1.75, 64.0, 127.0});
+ AddInputFromArray<float>(TensorShape({1}), {-128.0f});
AddInputFromArray<float>(TensorShape({1}), {100.0f});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QINT8, TensorShape({7}));
// Input element 0.0 should map to 0.
// Input element 127.0 maps to 127 instead of 100 because
// max(abs(-127), abs(100)) = 127.
- test::FillValues<qint8>(&expected, {-127, 0, 1, 1, 2, 64, 127});
+ test::FillValues<qint8>(&expected, {-128, 0, 1, 1, 2, 64, 127});
test::ExpectTensorEqual<qint8>(expected, *GetOutput(0));
Tensor expected_output_min(allocator(), DT_FLOAT, TensorShape({}));
- test::FillValues<float>(&expected_output_min, {-127.0});
+ test::FillValues<float>(&expected_output_min, {-128.0});
test::ExpectTensorEqual<float>(expected_output_min, *GetOutput(1));
Tensor expected_output_max(allocator(), DT_FLOAT, TensorShape({}));
@@ -152,9 +152,9 @@ TEST_F(QuantizedOpTest, QuantizeV2Qint8ScaledSmallInputRange) {
.Attr("mode", "SCALED")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
- AddInputFromArray<float>(TensorShape({3}), {-1.0, 0.0, 2.0});
- AddInputFromArray<float>(TensorShape({1}), {-1.0f});
- AddInputFromArray<float>(TensorShape({1}), {2.0f});
+ AddInputFromArray<float>(TensorShape({3}), {-0.064, 0.0, 0.127});
+ AddInputFromArray<float>(TensorShape({1}), {-0.064f});
+ AddInputFromArray<float>(TensorShape({1}), {0.127f});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QINT8, TensorShape({3}));
// Input element 0.0 should map to 0.
@@ -163,11 +163,11 @@ TEST_F(QuantizedOpTest, QuantizeV2Qint8ScaledSmallInputRange) {
test::ExpectTensorEqual<qint8>(expected, *GetOutput(0));
Tensor expected_output_min(allocator(), DT_FLOAT, TensorShape({}));
- test::FillValues<float>(&expected_output_min, {-2.0});
+ test::FillValues<float>(&expected_output_min, {-0.128});
test::ExpectTensorEqual<float>(expected_output_min, *GetOutput(1));
Tensor expected_output_max(allocator(), DT_FLOAT, TensorShape({}));
- test::FillValues<float>(&expected_output_max, {2.0});
+ test::FillValues<float>(&expected_output_max, {0.127});
test::ExpectTensorEqual<float>(expected_output_max, *GetOutput(2));
}
@@ -183,8 +183,8 @@ TEST_F(QuantizedOpTest, QuantizeV2Qint8ScaledRoundToEven) {
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({7}),
{-126.5, 0.0, 1.0, 2.5, 3.5, 64.0, 127.0});
- AddInputFromArray<float>(TensorShape({1}), {-127.0f});
- AddInputFromArray<float>(TensorShape({1}), {-127.0f});
+ AddInputFromArray<float>(TensorShape({1}), {-128.0f});
+ AddInputFromArray<float>(TensorShape({1}), {-128.0f});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QINT8, TensorShape({7}));
// Input element 0.0 should map to 0.
@@ -193,7 +193,7 @@ TEST_F(QuantizedOpTest, QuantizeV2Qint8ScaledRoundToEven) {
test::ExpectTensorEqual<qint8>(expected, *GetOutput(0));
Tensor expected_output_min(allocator(), DT_FLOAT, TensorShape({}));
- test::FillValues<float>(&expected_output_min, {-127.0});
+ test::FillValues<float>(&expected_output_min, {-128.0});
test::ExpectTensorEqual<float>(expected_output_min, *GetOutput(1));
Tensor expected_output_max(allocator(), DT_FLOAT, TensorShape({}));
@@ -213,8 +213,8 @@ TEST_F(QuantizedOpTest, QuantizeV2Qint8ScaledRoundAwayFromZero) {
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({7}),
{-126.5, 0.0, 1.0, 2.5, 3.5, 64.0, 127.0});
- AddInputFromArray<float>(TensorShape({1}), {-127.0f});
- AddInputFromArray<float>(TensorShape({1}), {-127.0f});
+ AddInputFromArray<float>(TensorShape({1}), {-128.0f});
+ AddInputFromArray<float>(TensorShape({1}), {-128.0f});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QINT8, TensorShape({7}));
// Input element 0.0 should map to 0.
@@ -223,7 +223,7 @@ TEST_F(QuantizedOpTest, QuantizeV2Qint8ScaledRoundAwayFromZero) {
test::ExpectTensorEqual<qint8>(expected, *GetOutput(0));
Tensor expected_output_min(allocator(), DT_FLOAT, TensorShape({}));
- test::FillValues<float>(&expected_output_min, {-127.0});
+ test::FillValues<float>(&expected_output_min, {-128.0});
test::ExpectTensorEqual<float>(expected_output_min, *GetOutput(1));
Tensor expected_output_max(allocator(), DT_FLOAT, TensorShape({}));
diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc
new file mode 100644
index 0000000000..5863a2c8e4
--- /dev/null
+++ b/tensorflow/core/kernels/regex_full_match_op.cc
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "re2/re2.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class RegexFullMatchOp : public OpKernel {
+ public:
+ explicit RegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<string>();
+
+ const Tensor* pattern_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
+ errors::InvalidArgument("Pattern must be scalar, but received ",
+ pattern_tensor->shape().DebugString()));
+ const string pattern = pattern_tensor->flat<string>()(0);
+ const RE2 match(pattern);
+ OP_REQUIRES(ctx, match.ok(),
+ errors::InvalidArgument("Invalid pattern: ", pattern,
+ ", error: ", match.error()));
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<bool>();
+ for (size_t i = 0; i < input_flat.size(); ++i) {
+ output_flat(i) = RE2::FullMatch(input_flat(i), match);
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
+ RegexFullMatchOp);
+
+} // namespace tensorflow
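The new RegexFullMatch kernel compiles the scalar pattern once and then applies `RE2::FullMatch` element-wise. A stripped-down sketch of that loop without the OpKernel plumbing; the pattern and inputs here are made up for illustration:

```c++
#include <cstdio>
#include <string>
#include <vector>

#include "re2/re2.h"

int main() {
  // Compile the pattern once (scalar input), then full-match each element,
  // as the kernel's Compute() loop does.
  const RE2 pattern("a.*c");
  if (!pattern.ok()) {
    fprintf(stderr, "invalid pattern: %s\n", pattern.error().c_str());
    return 1;
  }
  const std::vector<std::string> input = {"abc", "abd", "xabc"};
  for (const std::string& s : input) {
    printf("%s -> %d\n", s.c_str(), RE2::FullMatch(s, pattern) ? 1 : 0);
  }
  // Prints 1, 0, 0: FullMatch requires the whole string to match.
  return 0;
}
```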
diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc
index f5ebf0ea2e..722116f86f 100644
--- a/tensorflow/core/kernels/roll_op.cc
+++ b/tensorflow/core/kernels/roll_op.cc
@@ -285,7 +285,7 @@ class RollOp : public OpKernel {
dim_range[i] = dim_size_prod;
}
- Tensor* output = NULL;
+ Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
auto input_flat = input.flat<T>().data();
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 0caa7bd317..8ef6e77398 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -62,14 +62,57 @@ class ScatterNdOp : public OpKernel {
const Tensor& updates = c->input(1);
const Tensor& shape_input = c->input(2);
- OP_REQUIRES(c, shape_input.dims() == 1,
- errors::InvalidArgument("Shape must be a vector"));
+ OP_REQUIRES(c, indices.shape().dims() >= 1,
+ errors::InvalidArgument(
+ "Indices shape must have rank at least one. Found:",
+ indices.shape().DebugString()));
+ OP_REQUIRES(c, updates.shape().dims() >= 1,
+ errors::InvalidArgument(
+ "Updates shape must have rank at least one. Found:",
+ updates.shape().DebugString()));
auto vec = shape_input.flat<Index>();
TensorShape shape;
OP_REQUIRES_OK(c,
TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape));
+ OP_REQUIRES(
+ c,
+ (shape.num_elements() > 0 || (indices.shape().num_elements() == 0 &&
+ updates.shape().num_elements() == 0)),
+ errors::InvalidArgument(
+ "Indices and updates specified for empty output shape"));
+
+ const int64 outer_dims = indices.shape().dims() - 1;
+
+ for (int i = 0; i < outer_dims; ++i) {
+ OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
+ errors::InvalidArgument(
+ "Outer dimensions of indices and update must match. "
+ "Indices shape: ",
+ indices.shape().DebugString(),
+ ", updates shape:", updates.shape().DebugString()));
+ }
+
+ const int64 ix = indices.shape().dim_size(outer_dims);
+ OP_REQUIRES(
+ c, updates.shape().dims() - outer_dims == shape.dims() - ix,
+ errors::InvalidArgument("Inner dimensions of output shape must match "
+ "inner dimensions of updates shape. Output: ",
+ shape.DebugString(),
+ " updates: ", updates.shape().DebugString()));
+ for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
+ OP_REQUIRES(
+ c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
+ errors::InvalidArgument(
+ "The inner ", shape.dims() - ix,
+ " dimensions of output.shape=", shape.DebugString(),
+ " must match the inner ", updates.shape().dims() - outer_dims,
+ " dimensions of updates.shape=", updates.shape().DebugString()));
+ }
+ OP_REQUIRES(c, shape_input.dims() == 1,
+ errors::InvalidArgument("Shape must be a vector"));
+
Tensor out;
OP_REQUIRES_OK(
c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
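The added ScatterNd validation requires that the leading (outer) dimensions of `indices` and `updates` agree and that the trailing dimensions of `updates` match the output shape past the indexed prefix. A small illustrative checker over plain dimension vectors (not the TensorShape API) with a worked example:

```c++
#include <cstdio>
#include <vector>

// Illustrative check of the ScatterNd shape constraints using plain dimension
// vectors instead of TensorShape.
bool ShapesCompatible(const std::vector<int>& indices,
                      const std::vector<int>& updates,
                      const std::vector<int>& output) {
  const int outer_dims = static_cast<int>(indices.size()) - 1;
  const int ix = indices.back();  // indices[..., :ix] indexes the first ix
                                  // dimensions of the output shape.
  // Outer dimensions of indices and updates must match.
  for (int i = 0; i < outer_dims; ++i) {
    if (indices[i] != updates[i]) return false;
  }
  // The remaining dimensions of updates must match output[ix:].
  if (static_cast<int>(updates.size()) - outer_dims !=
      static_cast<int>(output.size()) - ix) {
    return false;
  }
  for (int i = 0; i + outer_dims < static_cast<int>(updates.size()); ++i) {
    if (updates[i + outer_dims] != output[ix + i]) return false;
  }
  return true;
}

int main() {
  // indices: [4, 2], updates: [4, 3], output shape: [5, 6, 3]
  // outer_dims = 1, ix = 2; updates[1] == output[2]  ->  OK.
  printf("%d\n", ShapesCompatible({4, 2}, {4, 3}, {5, 6, 3}));  // 1
  // updates: [4, 7] no longer matches output[2] == 3  ->  rejected.
  printf("%d\n", ShapesCompatible({4, 2}, {4, 7}, {5, 6, 3}));  // 0
  return 0;
}
```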
diff --git a/tensorflow/core/kernels/scoped_allocator_ops.cc b/tensorflow/core/kernels/scoped_allocator_ops.cc
index 1800ee8c1f..1d2fb6996a 100644
--- a/tensorflow/core/kernels/scoped_allocator_ops.cc
+++ b/tensorflow/core/kernels/scoped_allocator_ops.cc
@@ -113,7 +113,7 @@ class ScopedAllocatorConcatOp : public OpKernel {
OP_REQUIRES(context, backing_tensor.NumElements() >= shape_.num_elements(),
errors::InvalidArgument("Backing tensor num elements ",
backing_tensor.NumElements(),
- " is not equal to expected ",
+                                        " is not >= expected ",
shape_.num_elements()));
Tensor output(dtype_);
if (reshape_) {
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index c87ce78e05..2328fc6afd 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -320,7 +320,9 @@ class SegmentSumGPUOp : public AsyncOpKernel {
REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
type, index_type, 0); \
REGISTER_CPU_KERNEL_SEGMENT( \
- "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1)
+ "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
+ REGISTER_CPU_KERNEL_SEGMENT( \
+ "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);
#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
REGISTER_REAL_CPU_KERNELS(type, int32); \
diff --git a/tensorflow/core/lib/gtl/flatmap.h b/tensorflow/core/lib/gtl/flatmap.h
index 889d2ddaa6..9dc439c163 100644
--- a/tensorflow/core/lib/gtl/flatmap.h
+++ b/tensorflow/core/lib/gtl/flatmap.h
@@ -76,6 +76,10 @@ class FlatMap {
FlatMap(const FlatMap& src) : rep_(src.rep_) {}
+ // Move constructor leaves src in a valid but unspecified state (same as
+ // std::unordered_map).
+ FlatMap(FlatMap&& src) : rep_(std::move(src.rep_)) {}
+
template <typename InputIter>
FlatMap(InputIter first, InputIter last, size_t N = 1,
const Hash& hf = Hash(), const Eq& eq = Eq())
@@ -92,6 +96,13 @@ class FlatMap {
return *this;
}
+ // Move-assignment operator leaves src in a valid but unspecified state (same
+ // as std::unordered_map).
+ FlatMap& operator=(FlatMap&& src) {
+ rep_.MoveFrom(std::move(src.rep_));
+ return *this;
+ }
+
~FlatMap() {}
void swap(FlatMap& x) { rep_.swap(x.rep_); }
diff --git a/tensorflow/core/lib/gtl/flatmap_test.cc b/tensorflow/core/lib/gtl/flatmap_test.cc
index 0901eba926..0fd22ab37b 100644
--- a/tensorflow/core/lib/gtl/flatmap_test.cc
+++ b/tensorflow/core/lib/gtl/flatmap_test.cc
@@ -656,19 +656,33 @@ TEST(FlatMap, UniqueMap) {
}
EXPECT_EQ(map.size(), N);
+ // move constructor
+ UniqMap map2(std::move(map));
+
// Lookups
for (int i = 0; i < N; i++) {
- EXPECT_EQ(*map.at(MakeUniq(i)), i + 100);
+ EXPECT_EQ(*map2.at(MakeUniq(i)), i + 100);
}
+ // move assignment
+ UniqMap map3;
+ map3 = std::move(map2);
+
// find+erase
- EXPECT_EQ(map.count(MakeUniq(2)), 1);
- map.erase(MakeUniq(2));
- EXPECT_EQ(map.count(MakeUniq(2)), 0);
+ EXPECT_EQ(map3.count(MakeUniq(2)), 1);
+ map3.erase(MakeUniq(2));
+ EXPECT_EQ(map3.count(MakeUniq(2)), 0);
// clear
- map.clear();
- EXPECT_EQ(map.size(), 0);
+ map3.clear();
+ EXPECT_EQ(map3.size(), 0);
+
+ // Check that moved-from maps are in a valid (though unspecified) state.
+ EXPECT_GE(map.size(), 0);
+ EXPECT_GE(map2.size(), 0);
+ // This insert should succeed no matter what state `map` is in, because
+ // MakeUniq(-1) is never called above: This key can't possibly exist.
+ EXPECT_TRUE(map.emplace(MakeUniq(-1), MakeUniq(-1)).second);
}
TEST(FlatMap, UniqueMapIter) {
diff --git a/tensorflow/core/lib/gtl/flatrep.h b/tensorflow/core/lib/gtl/flatrep.h
index 0d7e7487fc..65a076b0f3 100644
--- a/tensorflow/core/lib/gtl/flatrep.h
+++ b/tensorflow/core/lib/gtl/flatrep.h
@@ -51,10 +51,23 @@ class FlatRep {
FlatRep(size_t N, const Hash& hf, const Eq& eq) : hash_(hf), equal_(eq) {
Init(N);
}
- explicit FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) {
+ FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) {
Init(src.size());
CopyEntries(src.array_, src.end_, CopyEntry());
}
+
+ FlatRep(FlatRep&& src)
+ // Copy rather than move src.hash_ and src.equal_. This is necessary to
+ // leave src in a valid state -- otherwise e.g. if hash_ is an
+ // std::function, moving it would null it out.
+ : hash_(src.hash_), equal_(src.equal_) {
+ // TODO(jlebar): Init(1) still allocates some memory, so this isn't as cheap
+ // as it could be. The fundamental problem is that we need to leave src in
+ // a valid state, and FlatRep *always* owns a nonzero amount of memory.
+ Init(1);
+ swap(src);
+ }
+
~FlatRep() {
clear_no_resize();
delete[] array_;
@@ -78,6 +91,12 @@ class FlatRep {
}
}
+ void MoveFrom(FlatRep&& src) {
+ if (this != &src) {
+ swap(src);
+ }
+ }
+
void clear_no_resize() {
for (Bucket* b = array_; b != end_; b++) {
for (uint32 i = 0; i < kWidth; i++) {
diff --git a/tensorflow/core/lib/gtl/flatset.h b/tensorflow/core/lib/gtl/flatset.h
index f31e3abe41..bb4356e46d 100644
--- a/tensorflow/core/lib/gtl/flatset.h
+++ b/tensorflow/core/lib/gtl/flatset.h
@@ -59,6 +59,10 @@ class FlatSet {
FlatSet(const FlatSet& src) : rep_(src.rep_) {}
+ // Move constructor leaves src in a valid but unspecified state (same as
+ // std::unordered_set).
+ FlatSet(FlatSet&& src) : rep_(std::move(src.rep_)) {}
+
template <typename InputIter>
FlatSet(InputIter first, InputIter last, size_t N = 1,
const Hash& hf = Hash(), const Eq& eq = Eq())
@@ -75,6 +79,13 @@ class FlatSet {
return *this;
}
+ // Move-assignment operator leaves src in a valid but unspecified state (same
+ // as std::unordered_set).
+ FlatSet& operator=(FlatSet&& src) {
+ rep_.MoveFrom(std::move(src.rep_));
+ return *this;
+ }
+
~FlatSet() {}
void swap(FlatSet& x) { rep_.swap(x.rep_); }
@@ -169,6 +180,7 @@ class FlatSet {
}
std::pair<iterator, bool> insert(const Key& k) { return Insert(k); }
+ std::pair<iterator, bool> insert(Key&& k) { return Insert(std::move(k)); }
template <typename InputIter>
void insert(InputIter first, InputIter last) {
for (; first != last; ++first) {
@@ -265,9 +277,10 @@ class FlatSet {
}
};
- std::pair<iterator, bool> Insert(const Key& k) {
+ template <typename K>
+ std::pair<iterator, bool> Insert(K&& k) {
rep_.MaybeResize();
- auto r = rep_.FindOrInsert(k);
+ auto r = rep_.FindOrInsert(std::forward<K>(k));
const bool inserted = !r.found;
return {iterator(r.b, rep_.limit(), r.index), inserted};
}
diff --git a/tensorflow/core/lib/gtl/flatset_test.cc b/tensorflow/core/lib/gtl/flatset_test.cc
index 010b4bb5df..7f0138404f 100644
--- a/tensorflow/core/lib/gtl/flatset_test.cc
+++ b/tensorflow/core/lib/gtl/flatset_test.cc
@@ -552,18 +552,32 @@ TEST(FlatSet, UniqueSet) {
}
EXPECT_EQ(set.size(), N);
+ // Move constructor
+ UniqSet set2(std::move(set));
+
// Lookups
for (int i = 0; i < N; i++) {
- EXPECT_EQ(set.count(MakeUniq(i)), 1);
+ EXPECT_EQ(set2.count(MakeUniq(i)), 1);
}
+ // Move-assignment operator
+ UniqSet set3;
+ set3 = std::move(set2);
+
// erase
- set.erase(MakeUniq(2));
- EXPECT_EQ(set.count(MakeUniq(2)), 0);
+ set3.erase(MakeUniq(2));
+ EXPECT_EQ(set3.count(MakeUniq(2)), 0);
// clear
set.clear();
EXPECT_EQ(set.size(), 0);
+
+ // Check that moved-from sets are in a valid (though unspecified) state.
+ EXPECT_GE(set.size(), 0);
+ EXPECT_GE(set2.size(), 0);
+ // This insert should succeed no matter what state `set` is in, because
+ // MakeUniq(-1) is never called above: This key can't possibly exist.
+ EXPECT_TRUE(set.emplace(MakeUniq(-1)).second);
}
TEST(FlatSet, UniqueSetIter) {
@@ -579,6 +593,12 @@ TEST(FlatSet, UniqueSetIter) {
EXPECT_EQ(sum, (kCount * (kCount + 1)) / 2);
}
+TEST(FlatSet, InsertUncopyable) {
+ UniqSet set;
+ EXPECT_TRUE(set.insert(MakeUniq(0)).second);
+ EXPECT_EQ(set.size(), 1);
+}
+
/* This would be a good negative compilation test, if we could do that.
TEST(FlatSet, MutableIterator_ShouldNotCompile) {
diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h
index 3f85303c0f..737d23f699 100644
--- a/tensorflow/core/lib/hash/hash.h
+++ b/tensorflow/core/lib/hash/hash.h
@@ -44,6 +44,12 @@ inline uint64 Hash64Combine(uint64 a, uint64 b) {
return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4));
}
+// Combine two hashes in an order-independent way. This operation should be
+// associative and compute the same hash for a collection of elements
+// independent of traversal order. Note that it is better to combine hashes
+// symmetrically with addition rather than XOR, since (x^x) == 0 but (x+x) != 0.
+inline uint64 Hash64CombineUnordered(uint64 a, uint64 b) { return a + b; }
+
// Hash functor suitable for use with power-of-two sized hashtables. Use
// instead of std::hash<T>.
//
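The new Hash64CombineUnordered relies on addition being commutative and associative, so folding a set of element hashes yields the same value regardless of traversal order. A small standalone check; the element hash values are illustrative only, not the tensorflow::Hash64 implementation:

```c++
#include <cassert>
#include <cstdint>

// Order-independent combiner, mirroring the addition-based combine added in
// hash.h.
inline uint64_t CombineUnordered(uint64_t a, uint64_t b) { return a + b; }

int main() {
  const uint64_t ha = 0x9e3779b97f4a7800ULL;  // stand-in element hashes
  const uint64_t hb = 0x12345678aabbccddULL;
  const uint64_t hc = 0x0fedcba987654321ULL;

  // Folding in any order produces the same combined hash ...
  const uint64_t abc = CombineUnordered(CombineUnordered(ha, hb), hc);
  const uint64_t cba = CombineUnordered(CombineUnordered(hc, hb), ha);
  assert(abc == cba);

  // ... whereas XOR would cancel duplicates: (x ^ x) == 0, but (x + x) != 0
  // for nonzero x, so repeated elements still contribute.
  assert(CombineUnordered(ha, ha) != 0);
  return 0;
}
```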
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 6880ceb505..c867674489 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -14642,6 +14642,66 @@ op {
}
}
op {
+ name: "CropAndResize"
+ input_arg {
+ name: "image"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "boxes"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "box_ind"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "crop_size"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "crops"
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "method"
+ type: "string"
+ default_value {
+ s: "bilinear"
+ }
+ allowed_values {
+ list {
+ s: "bilinear"
+ s: "nearest"
+ }
+ }
+ }
+ attr {
+ name: "extrapolation_value"
+ type: "float"
+ default_value {
+ f: 0
+ }
+ }
+}
+op {
name: "CropAndResizeGradBoxes"
input_arg {
name: "grads"
@@ -14791,6 +14851,53 @@ op {
}
}
op {
+ name: "CropAndResizeGradImage"
+ input_arg {
+ name: "grads"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "boxes"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "box_ind"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "image_size"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "method"
+ type: "string"
+ default_value {
+ s: "bilinear"
+ }
+ allowed_values {
+ list {
+ s: "bilinear"
+ s: "nearest"
+ }
+ }
+ }
+}
+op {
name: "Cross"
input_arg {
name: "a"
@@ -34457,6 +34564,33 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV3"
+ input_arg {
+ name: "boxes"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "scores"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+}
+op {
name: "NotEqual"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index c3b08e067a..d949e70c66 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -435,6 +435,25 @@ REGISTER_OP("DrawBoundingBoxes")
.Output("output: T")
.Attr("T: {float, half} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
+ // The rank of images should be 4.
+ ShapeHandle images;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &images));
+ // Channel depth should be either 1 (GRY), 3 (RGB), or 4 (RGBA).
+ if (c->ValueKnown(c->Dim(images, 3))) {
+ int64 depth = c->Value(c->Dim(images, 3));
+ if (!(depth == 1 || depth == 3 || depth == 4)) {
+ return errors::InvalidArgument("Channel depth should be either 1 (GRY), "
+ "3 (RGB), or 4 (RGBA)");
+ }
+ }
+
+ // The rank of boxes is 3: [batch, num_bounding_boxes, 4].
+ ShapeHandle boxes;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &boxes));
+ // The last value of boxes shape is 4.
+ DimensionHandle unused;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 2), 4, &unused));
+
return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
});
@@ -548,7 +567,7 @@ REGISTER_OP("CropAndResize")
.Input("crop_size: int32")
.Output("crops: float")
.Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
- .Attr("method: {'bilinear'} = 'bilinear'")
+ .Attr("method: {'bilinear', 'nearest'} = 'bilinear'")
.Attr("extrapolation_value: float = 0")
.SetShapeFn([](InferenceContext* c) {
// Get inputs and validate ranks.
@@ -579,7 +598,7 @@ REGISTER_OP("CropAndResizeGradImage")
.Input("image_size: int32")
.Output("output: T")
.Attr("T: {float, half, double}")
- .Attr("method: {'bilinear'} = 'bilinear'")
+ .Attr("method: {'bilinear', 'nearest'} = 'bilinear'")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
@@ -657,4 +676,35 @@ REGISTER_OP("NonMaxSuppressionV2")
return Status::OK();
});
+REGISTER_OP("NonMaxSuppressionV3")
+ .Input("boxes: float")
+ .Input("scores: float")
+ .Input("max_output_size: int32")
+ .Input("iou_threshold: float")
+ .Input("score_threshold: float")
+ .Output("selected_indices: int32")
+ .SetShapeFn([](InferenceContext* c) {
+ // Get inputs and validate ranks.
+ ShapeHandle boxes;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
+ ShapeHandle scores;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
+ ShapeHandle max_output_size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
+ ShapeHandle iou_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
+ ShapeHandle score_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
+      // The boxes input is a 2-D float Tensor of shape [num_boxes, 4].
+ DimensionHandle unused;
+ // The boxes[0] and scores[0] are both num_boxes.
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
+ // The boxes[1] is 4.
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
+
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
+ });
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/image_ops_test.cc b/tensorflow/core/ops/image_ops_test.cc
index 5f0b391b0d..517af26b44 100644
--- a/tensorflow/core/ops/image_ops_test.cc
+++ b/tensorflow/core/ops/image_ops_test.cc
@@ -312,4 +312,23 @@ TEST(ImageOpsTest, QuantizedResizeBilinear_ShapeFn) {
INFER_OK(op, "[1,?,3,?];[2];[];[]", "[d0_0,20,30,d0_3];[];[]");
}
+TEST(ImageOpsTest, DrawBoundingBoxes_ShapeFn) {
+ ShapeInferenceTestOp op("DrawBoundingBoxes");
+ op.input_tensors.resize(2);
+
+ // Check images.
+ INFER_ERROR("must be rank 4", op, "[1,?,3];?");
+ INFER_ERROR("should be either 1 (GRY), 3 (RGB), or 4 (RGBA)",
+ op, "[1,?,?,5];?");
+
+ // Check boxes.
+ INFER_ERROR("must be rank 3", op, "[1,?,?,4];[1,4]");
+ INFER_ERROR("Dimension must be 4", op, "[1,?,?,4];[1,2,2]");
+
+ // OK shapes.
+ INFER_OK(op, "[4,?,?,4];?", "in0");
+ INFER_OK(op, "[?,?,?,?];[?,?,?]", "in0");
+ INFER_OK(op, "[4,?,?,4];[?,?,?]", "in0");
+ INFER_OK(op, "[4,?,?,4];[?,?,4]", "in0");
+}
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 8f8443a46c..8c0b073ce4 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1017,7 +1017,7 @@ REGISTER_OP("SegmentMean")
.Input("data: T")
.Input("segment_ids: Tindices")
.Output("output: T")
- .Attr("T: realnumbertype")
+ .Attr("T: numbertype")
.Attr("Tindices: {int32,int64}")
.SetShapeFn(SegmentReductionShapeFn);
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index bb46dafd42..fc60e807b9 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -547,7 +547,7 @@ REGISTER_OP("Conv3DBackpropFilter")
});
REGISTER_OP("Conv3DBackpropInputV2")
- .Input("input_sizes: int32")
+ .Input("input_sizes: Tshape")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
@@ -556,6 +556,7 @@ REGISTER_OP("Conv3DBackpropInputV2")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
.Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .Attr("Tshape: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index d741598b19..e45125a1e8 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -6242,6 +6242,7 @@ op {
allowed_values {
list {
s: "bilinear"
+ s: "nearest"
}
}
}
@@ -6347,6 +6348,7 @@ op {
allowed_values {
list {
s: "bilinear"
+ s: "nearest"
}
}
}
@@ -16568,6 +16570,33 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV3"
+ input_arg {
+ name: "boxes"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "scores"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+}
+op {
name: "NotEqual"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 416ce9c0d8..80ffae5796 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -72,7 +72,15 @@ REGISTER_OP("ParameterizedTruncatedNormal")
.Attr("seed2: int = 0")
.Attr("dtype: {half,bfloat16,float,double}")
.Attr("T: {int32, int64}")
- .SetShapeFn(shape_inference::RandomShape);
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ // Parameters must be 0-d or 1-d.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &unused));
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused));
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused));
+ return shape_inference::RandomShape(c);
+ });
REGISTER_OP("TruncatedNormal")
.Input("shape: T")
diff --git a/tensorflow/core/ops/rpc_ops.cc b/tensorflow/core/ops/rpc_ops.cc
index 72fda5e6eb..136f96d9ea 100644
--- a/tensorflow/core/ops/rpc_ops.cc
+++ b/tensorflow/core/ops/rpc_ops.cc
@@ -18,7 +18,6 @@ limitations under the License.
namespace tensorflow {
-using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 469f193cf4..1d5c743a56 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -37,6 +37,17 @@ REGISTER_OP("RegexReplace")
return Status::OK();
});
+REGISTER_OP("RegexFullMatch")
+ .Input("input: string")
+ .Input("pattern: string")
+ .Output("output: bool")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ });
+
REGISTER_OP("StringToHashBucketFast")
.Input("input: string")
.Output("output: int64")
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 08697a45d8..921e98c31f 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -997,6 +997,10 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
request->SetResultBuffer(&output_buffer);
request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
+ if (stats_ != nullptr) {
+ stats_->RecordStatObjectRequest();
+ }
+
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(),
" when reading metadata of gs://", bucket,
"/", object);
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index 6250aa7594..d095773770 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PLATFORM_GCS_FILE_SYSTEM_H_
-#define TENSORFLOW_CORE_PLATFORM_GCS_FILE_SYSTEM_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_FILE_SYSTEM_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_FILE_SYSTEM_H_
#include <string>
#include <utility>
@@ -56,6 +56,10 @@ class GcsStatsInterface {
virtual void RecordBlockRetrieved(const string& file, size_t offset,
size_t bytes_transferred) = 0;
+  // RecordStatObjectRequest is called when a Stat() request for a GCS object
+  // is about to be made.
+ virtual void RecordStatObjectRequest() = 0;
+
/// HttpStats is called to optionally provide a RequestStats listener
/// to be annotated on every HTTP request made to the GCS API.
///
@@ -264,4 +268,4 @@ class RetryingGcsFileSystem : public RetryingFileSystem<GcsFileSystem> {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_GCS_FILE_SYSTEM_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_FILE_SYSTEM_H_
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 28be13869b..4b594e5e61 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -2833,41 +2833,71 @@ TEST(GcsFileSystemTest, CreateHttpRequest) {
TF_EXPECT_OK(request->Send());
}
-TEST(GcsFileSystemTest, NewRandomAccessFile_StatsRecording) {
- class TestGcsStats : public GcsStatsInterface {
- public:
- void Init(GcsFileSystem* fs, GcsThrottle* throttle,
- const FileBlockCache* block_cache) override {
- CHECK(fs_ == nullptr);
- CHECK(throttle_ == nullptr);
- CHECK(block_cache_ == nullptr);
-
- fs_ = fs;
- throttle_ = throttle;
- block_cache_ = block_cache;
- }
-
- void RecordBlockLoadRequest(const string& file, size_t offset) override {
- block_load_request_file_ = file;
- }
-
- void RecordBlockRetrieved(const string& file, size_t offset,
- size_t bytes_transferred) override {
- block_retrieved_file_ = file;
- block_retrieved_bytes_transferred_ = bytes_transferred;
- }
-
- HttpRequest::RequestStats* HttpStats() override { return nullptr; }
-
- GcsFileSystem* fs_ = nullptr;
- GcsThrottle* throttle_ = nullptr;
- const FileBlockCache* block_cache_ = nullptr;
-
- string block_load_request_file_;
- string block_retrieved_file_;
- size_t block_retrieved_bytes_transferred_ = 0;
- };
+class TestGcsStats : public GcsStatsInterface {
+ public:
+ void Init(GcsFileSystem* fs, GcsThrottle* throttle,
+ const FileBlockCache* block_cache) override {
+ CHECK(fs_ == nullptr);
+ CHECK(throttle_ == nullptr);
+ CHECK(block_cache_ == nullptr);
+
+ fs_ = fs;
+ throttle_ = throttle;
+ block_cache_ = block_cache;
+ }
+
+ void RecordBlockLoadRequest(const string& file, size_t offset) override {
+ block_load_request_file_ = file;
+ }
+
+ void RecordBlockRetrieved(const string& file, size_t offset,
+ size_t bytes_transferred) override {
+ block_retrieved_file_ = file;
+ block_retrieved_bytes_transferred_ = bytes_transferred;
+ }
+
+ void RecordStatObjectRequest() override { stat_object_request_count_++; }
+
+ HttpRequest::RequestStats* HttpStats() override { return nullptr; }
+
+ GcsFileSystem* fs_ = nullptr;
+ GcsThrottle* throttle_ = nullptr;
+ const FileBlockCache* block_cache_ = nullptr;
+
+ string block_load_request_file_;
+ string block_retrieved_file_;
+ size_t block_retrieved_bytes_transferred_ = 0;
+ int stat_object_request_count_ = 0;
+};
+
+TEST(GcsFileSystemTest, Stat_StatsRecording) {
+ std::vector<HttpRequest*> requests({new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+ "file.txt?fields=size%2Cupdated\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ strings::StrCat("{\"size\": \"1010\","
+ "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */,
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
+ TestGcsStats stats;
+ fs.SetStats(&stats);
+ EXPECT_EQ(stats.fs_, &fs);
+
+ FileStatistics stat;
+ TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat));
+ EXPECT_EQ(1, stats.stat_object_request_count_);
+}
+
+TEST(GcsFileSystemTest, NewRandomAccessFile_StatsRecording) {
std::vector<HttpRequest*> requests({new FakeHttpRequest(
"Uri: https://storage.googleapis.com/bucket/random_access.txt\n"
"Auth Token: fake_token\n"
diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc
index 59ad3cbcc2..e64653a67a 100644
--- a/tensorflow/core/platform/cloud/oauth_client.cc
+++ b/tensorflow/core/platform/cloud/oauth_client.cc
@@ -97,7 +97,7 @@ Status CreateSignature(RSA* private_key, StringPiece to_sign,
}
std::unique_ptr<EVP_MD_CTX, std::function<void(EVP_MD_CTX*)>> md_ctx(
EVP_MD_CTX_create(), [](EVP_MD_CTX* ptr) { EVP_MD_CTX_destroy(ptr); });
- if (!md_ctx.get()) {
+ if (!md_ctx) {
return errors::Internal("Could not create MD_CTX.");
}
@@ -196,7 +196,7 @@ Status OAuthClient::GetTokenFromServiceAccountJson(
std::unique_ptr<RSA, std::function<void(RSA*)>> private_key(
PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr),
[](RSA* ptr) { RSA_free(ptr); });
- if (!private_key.get()) {
+ if (!private_key) {
return errors::Internal("Could not deserialize the private key.");
}
diff --git a/tensorflow/core/platform/error.h b/tensorflow/core/platform/error.h
new file mode 100644
index 0000000000..ae965b6c77
--- /dev/null
+++ b/tensorflow/core/platform/error.h
@@ -0,0 +1,30 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_ERROR_H_
+#define TENSORFLOW_CORE_PLATFORM_ERROR_H_
+
+#include "tensorflow/core/platform/platform.h"
+
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX) || \
+ defined(PLATFORM_POSIX_ANDROID) || defined(PLATFORM_GOOGLE_ANDROID)
+#include "tensorflow/core/platform/posix/error.h"
+#elif defined(PLATFORM_WINDOWS)
+#include "tensorflow/core/platform/windows/error.h"
+#else
+#error Define the appropriate PLATFORM_<foo> macro for this platform
+#endif
+
+#endif // TENSORFLOW_CORE_PLATFORM_ERROR_H_
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index a8cb40502c..72c12318ca 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -21,11 +21,11 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/error.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/posix/error.h"
#include "third_party/hadoop/hdfs.h"
namespace tensorflow {
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 5372ef24b8..bcc7ac1804 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -47,6 +47,9 @@ message RewriterConfig {
// Statically infer the value of tensors when possible, and materialize the
// result using constants.
Toggle constant_folding = 3;
+ // Shape optimizations (default is OFF)
+  // Simplify computations made on shapes.
+ Toggle shape_optimization = 13;
// Arithmetic optimizations (default is ON)
// e.g. Simplify arithmetic ops; merge ops with same value (like constants).
Toggle arithmetic_optimization = 7;
diff --git a/tensorflow/core/protobuf/transport_options.proto b/tensorflow/core/protobuf/transport_options.proto
new file mode 100644
index 0000000000..d7b1bddbbe
--- /dev/null
+++ b/tensorflow/core/protobuf/transport_options.proto
@@ -0,0 +1,8 @@
+syntax = "proto3";
+
+package tensorflow;
+
+// Extra data needed on a non-RDMA RecvBufResponse.
+message RecvBufRespExtra {
+ bytes tensor_content = 1;
+};
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index d714d85ce6..b68d970876 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -418,6 +418,60 @@ message TracingResponse {
////////////////////////////////////////////////////////////////////////////////
//
+// Raw data transfers in support of Collective Ops.
+// These methods are experimental and subject to change.
+//
+// The intention is to allow collectives to take advantage of the most
+// efficient methods available on a platform, e.g. RDMA, and not be
+// constrained to use the RPC system in use by other methods.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message RecvBufRequest {
+ // Use of the fields below may vary by implementation. For example
+ // the buf_ptr and num_bytes may be set only for local operations and
+ // not sent on the wire, or only sent on the wire in one direction.
+
+ // Used at server side to find the correct BufRendezvous.
+ int64 step_id = 1;
+
+ // Arbitrary string identifying a BufRendezvous entry.
+ string buf_rendezvous_key = 2;
+
+ // Size of value expected, must agree with BufRendezvous entry.
+ int64 num_bytes = 3;
+
+ // When RDMA is in use, address of destination field on client.
+ fixed64 buf_ptr = 4;
+
+ // Optional information on client-side device locality.
+ DeviceLocality client_locality = 5;
+
+ // Optional information on server-side device locality.
+ DeviceLocality server_locality = 6;
+
+ // Optional, implementation-specific data.
+ google.protobuf.Any transport_options = 7;
+ // Optional, for annotating the timeline.
+ string src_device = 8;
+ string dst_device = 9;
+}
+
+message RecvBufResponse {
+ // Use of the fields below may vary by implementation. Comments give
+ // intended use.
+
+ fixed64 buf_ptr = 1; // Address of source field on server.
+ int64 num_bytes = 2; // Byte length of buf_ptr field, if set.
+ bool is_dead = 3; // True if value is 'dead' like a tensor.
+ // Optional, implementation-specific data.
+ google.protobuf.Any transport_options = 4;
+ // Optional, for timeline.
+ int64 send_start_micros = 5;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
// Collective Op dynamic group resolution messages.
//
////////////////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/core/protobuf/worker_service.proto b/tensorflow/core/protobuf/worker_service.proto
index 025fa7ca59..9ebbd553f2 100644
--- a/tensorflow/core/protobuf/worker_service.proto
+++ b/tensorflow/core/protobuf/worker_service.proto
@@ -74,6 +74,10 @@ service WorkerService {
rpc Tracing(TracingRequest) returns (TracingResponse);
// See worker.proto for details.
+ rpc RecvBuf(RecvBufRequest) returns (RecvBufResponse) {
+ }
+
+ // See worker.proto for details.
rpc GetStepSequence(GetStepSequenceRequest) returns (GetStepSequenceResponse);
// See worker.proto for details.
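
A hedged sketch of filling in the request message for the new `RecvBuf` method, using only the fields defined in worker.proto above; the helper signature and device names are illustrative placeholders:

```c++
#include <string>

#include "tensorflow/core/protobuf/worker.pb.h"

// Fill in a RecvBufRequest for the experimental RecvBuf RPC. num_bytes
// must agree with the corresponding BufRendezvous entry on the server.
tensorflow::RecvBufRequest MakeRecvBufRequest(int64_t step_id,
                                              const std::string& key,
                                              int64_t num_bytes) {
  tensorflow::RecvBufRequest req;
  req.set_step_id(step_id);
  req.set_buf_rendezvous_key(key);
  req.set_num_bytes(num_bytes);
  // Optional timeline annotations (placeholder device names).
  req.set_src_device("/job:worker/replica:0/task:0/device:CPU:0");
  req.set_dst_device("/job:worker/replica:0/task:1/device:CPU:0");
  return req;
}
```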
diff --git a/tensorflow/docs_src/community/swift.md b/tensorflow/docs_src/community/swift.md
index e5a0f02a8c..d1625d3b93 100644
--- a/tensorflow/docs_src/community/swift.md
+++ b/tensorflow/docs_src/community/swift.md
@@ -8,7 +8,7 @@ Welcome to the Swift for TensorFlow development community!
Swift for TensorFlow is a new way to develop machine learning models. It
gives you the power of
-[TensorFlow](programmers_guide/eager) directly
+[TensorFlow](https://www.tensorflow.org) directly
integrated into the [Swift programming language](https://swift.org/about).
With Swift, you can write the following imperative code, and Swift
automatically turns it into **a single TensorFlow Graph** and runs it
@@ -18,7 +18,7 @@ with the full performance of TensorFlow Sessions on CPU, GPU and
```swift
import TensorFlow
-var x = Tensor([[1, 2], [3, 4]])
+var x = Tensor<Float>([[1, 2], [3, 4]])
for i in 1...5 {
x += x ⊗ x
@@ -28,8 +28,8 @@ print(x)
```
Swift combines the flexibility of
-[Eager Execution](programmers_guide/eager) with the
-high performance of [Graphs and Sessions](programmers_guide/graphs).
+[Eager Execution](https://www.tensorflow.org/programmers_guide/eager) with the
+high performance of [Graphs and Sessions](https://www.tensorflow.org/programmers_guide/graphs).
Behind the scenes, Swift analyzes your Tensor code and automatically builds
graphs for you. Swift also catches type errors and shape mismatches before
running your code, and has [Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
diff --git a/tensorflow/docs_src/extend/adding_an_op.md b/tensorflow/docs_src/extend/adding_an_op.md
index c3795492ce..1b028be4ea 100644
--- a/tensorflow/docs_src/extend/adding_an_op.md
+++ b/tensorflow/docs_src/extend/adding_an_op.md
@@ -863,48 +863,53 @@ REGISTER_OP("ZeroOut")
Instead of writing another `OpKernel` with redundant code as above, often you
will be able to use a C++ template instead. You will still have one kernel
registration (`REGISTER_KERNEL_BUILDER` call) per overload.
-<pre class="prettyprint"><code class="lang-cpp">
-<b>template &lt;typename T&gt;</b>
+```c++
+template <typename T>
class ZeroOutOp : public OpKernel {
public:
- explicit ZeroOutOp(OpKernelConstruction\* context) : OpKernel(context) {}<br/>
- void Compute(OpKernelContext\* context) override {
+ explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
// Grab the input tensor
- const Tensor& input\_tensor = context-&gt;input(0);
- auto input = input\_tensor.flat<b>&lt;T&gt;</b>();<br/>
+ const Tensor& input_tensor = context->input(0);
+ auto input = input_tensor.flat<T>();
+
// Create an output tensor
Tensor* output = NULL;
- OP\_REQUIRES\_OK(context,
- context-&gt;allocate\_output(0, input_tensor.shape(), &output));
- auto output\_flat = output-&gt;template flat<b>&lt;T&gt;</b>();<br/>
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_tensor.shape(), &output));
+ auto output_flat = output->template flat<T>();
+
// Set all the elements of the output tensor to 0
const int N = input.size();
- for (int i = 0; i &lt; N; i++) {
- output\_flat(i) = 0;
- }<br/>
+ for (int i = 0; i < N; i++) {
+ output_flat(i) = 0;
+ }
+
// Preserve the first input value
- if (N &gt; 0) output\_flat(0) = input(0);
+ if (N > 0) output_flat(0) = input(0);
}
-};<br/>
-// Note that TypeConstraint&lt;int32&gt;("T") means that attr "T" (defined
+};
+
+// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
-// instantiation.</b>
-REGISTER\_KERNEL\_BUILDER(
+// instantiation.
+REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
- .Device(DEVICE\_CPU)
- .TypeConstraint&lt;int32&gt;("T"),
- <b>ZeroOutOp&lt;int32&gt;</b>);
-REGISTER\_KERNEL\_BUILDER(
+ .Device(DEVICE_CPU)
+ .TypeConstraint<int32>("T"),
+ ZeroOutOp<int32>);
+REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
- .Device(DEVICE\_CPU)
- .TypeConstraint&lt;float&gt;("T"),
- <b>ZeroOutOp&lt;float&gt;</b>);
-<b>REGISTER\_KERNEL\_BUILDER(
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ ZeroOutOp<float>);
+REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
- .Device(DEVICE\_CPU)
- .TypeConstraint&lt;double&gt;("T"),
- ZeroOutOp&lt;double&gt;);
-</b></code></pre>
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T"),
+ ZeroOutOp<double>);
+```
If you have more than a couple overloads, you can put the registration in a
macro.
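
For illustration, one possible shape for such a macro, based on the registrations above; the macro name here is illustrative, not necessarily the one the guide goes on to define:

```c++
// Register ZeroOutOp<type> on CPU for a given element type.
#define REGISTER_ZEROOUT_KERNEL(type)                                \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"),  \
      ZeroOutOp<type>)

REGISTER_ZEROOUT_KERNEL(int32);
REGISTER_ZEROOUT_KERNEL(float);
REGISTER_ZEROOUT_KERNEL(double);

#undef REGISTER_ZEROOUT_KERNEL
```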
diff --git a/tensorflow/docs_src/mobile/tflite/index.md b/tensorflow/docs_src/mobile/tflite/index.md
index 01881ccf3b..5622034827 100644
--- a/tensorflow/docs_src/mobile/tflite/index.md
+++ b/tensorflow/docs_src/mobile/tflite/index.md
@@ -155,7 +155,7 @@ retraining for both floating point and quantized inference.
The following diagram shows the architectural design of TensorFlow Lite:
-<img src="/images/tflite-architecture.jpg"
+<img src="https://www.tensorflow.org/images/tflite-architecture.jpg"
alt="TensorFlow Lite architecture diagram"
style="max-width:600px;">
diff --git a/tensorflow/docs_src/performance/xla/broadcasting.md b/tensorflow/docs_src/performance/xla/broadcasting.md
index ca3bddf758..eaa709c2f8 100644
--- a/tensorflow/docs_src/performance/xla/broadcasting.md
+++ b/tensorflow/docs_src/performance/xla/broadcasting.md
@@ -97,9 +97,9 @@ shape is broadcast into a larger rank shape. For example, given a 2x3x4 cuboid
and a 3x4 matrix, a broadcasting tuple (1,2) means matching the matrix to
dimensions 1 and 2 of the cuboid.
-This type of broadcast is used in the binary ops in `ComputationBuilder`, if the
+This type of broadcast is used in the binary ops in `XlaBuilder`, if the
`broadcast_dimensions` argument is given. For example, see
-[ComputationBuilder::Add](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.cc).
+[XlaBuilder::Add](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.cc).
In the XLA source code, this type of broadcasting is sometimes called "InDim"
broadcasting.
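
As an illustration of the `broadcast_dimensions` argument, a minimal `XlaBuilder` sketch matching the 2x3x4 cuboid / 3x4 matrix example above, assuming the builder methods keep the signatures used elsewhere in these docs:

```c++
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Add a 3x4 matrix to a 2x3x4 cuboid by matching the matrix to
// dimensions 1 and 2 of the cuboid (the "InDim" broadcast described above).
void BuildInDimBroadcastExample() {
  xla::XlaBuilder builder("indim_broadcast");
  auto cuboid = builder.Parameter(
      0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3, 4}), "cuboid");
  auto matrix = builder.Parameter(
      1, xla::ShapeUtil::MakeShape(xla::F32, {3, 4}), "matrix");
  builder.Add(cuboid, matrix, /*broadcast_dimensions=*/{1, 2});
}
```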
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 21e4c71a60..5887c3d88b 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1,7 +1,7 @@
# Operation Semantics
The following describes the semantics of operations defined in the
-[`ComputationBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h)
+[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
interface. Typically, these operations map one-to-one to operations defined in
the RPC interface in
[`xla_data.proto`](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto).
@@ -16,7 +16,7 @@ and familiar names; for example a *vector* is a 1-dimensional array and a
## BatchNormGrad
See also
-[`ComputationBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h)
+[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.
@@ -26,14 +26,14 @@ Calculates gradients of batch norm.
| Arguments | Type | Semantics |
| --------------- | ----------------------- | -------------------------------- |
-| `operand` | `ComputationDataHandle` | n dimensional array to be |
+| `operand` | `XlaOp` | n dimensional array to be |
: : : normalized (x) :
-| `scale` | `ComputationDataHandle` | 1 dimensional array |
+| `scale` | `XlaOp` | 1 dimensional array |
: : : (\\(\gamma\\)) :
-| `mean` | `ComputationDataHandle` | 1 dimensional array (\\(\mu\\)) |
-| `variance` | `ComputationDataHandle` | 1 dimensional array |
+| `mean` | `XlaOp` | 1 dimensional array (\\(\mu\\)) |
+| `variance` | `XlaOp` | 1 dimensional array |
: : : (\\(\sigma^2\\)) :
-| `grad_output` | `ComputationDataHandle` | Gradients passed to |
+| `grad_output` | `XlaOp` | Gradients passed to |
: : : `BatchNormTraining` :
: : : (\\( \nabla y\\)) :
| `epsilon` | `float` | Epsilon value (\\(\epsilon\\)) |
@@ -70,35 +70,33 @@ The output type is a tuple of three handles:
| Outputs | Type | Semantics |
| ------------- | ----------------------- | --------------------------------- |
-| `grad_operand` | `ComputationDataHandle` | gradient with respect to input |
+| `grad_operand` | `XlaOp` | gradient with respect to input |
: : : `operand` (\\( \nabla x\\)) :
-| `grad_scale` | `ComputationDataHandle` | gradient with respect to input |
+| `grad_scale` | `XlaOp` | gradient with respect to input |
: : : `scale` (\\( \nabla \gamma\\)) :
-| `grad_offset` | `ComputationDataHandle` | gradient with respect to input |
+| `grad_offset` | `XlaOp` | gradient with respect to input |
: : : `offset`(\\( \nabla \beta\\)) :
## BatchNormInference
See also
-[`ComputationBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h) and
-[the original batch normalization paper](https://arxiv.org/abs/1502.03167)
+[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
+and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.
Normalizes an array across batch and spatial dimensions.
<b> `BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)` </b>
-| Arguments | Type | Semantics |
-| -------------- | ----------------------- | ------------------------------- |
-| `operand` | `ComputationDataHandle` | n dimensional array to be |
-: : : normalized :
-| `scale` | `ComputationDataHandle` | 1 dimensional array |
-| `offset` | `ComputationDataHandle` | 1 dimensional array |
-| `mean` | `ComputationDataHandle` | 1 dimensional array |
-| `variance` | `ComputationDataHandle` | 1 dimensional array |
-| `epsilon` | `float` | Epsilon value |
-| `feature_index` | `int64` | Index to feature dimension in |
-: : : `operand` :
+Arguments | Type | Semantics
+--------------- | ------- | ---------------------------------------
+`operand` | `XlaOp` | n dimensional array to be normalized
+`scale` | `XlaOp` | 1 dimensional array
+`offset` | `XlaOp` | 1 dimensional array
+`mean` | `XlaOp` | 1 dimensional array
+`variance` | `XlaOp` | 1 dimensional array
+`epsilon` | `float` | Epsilon value
+`feature_index` | `int64` | Index to feature dimension in `operand`
For each feature in the feature dimension (`feature_index` is the index for the
feature dimension in `operand`), the operation calculates the mean and variance
@@ -117,25 +115,21 @@ The output is an n-dimensional, normalized array with the same shape as input
## BatchNormTraining
See also
-[`ComputationBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h) and
-[`the original batch normalization paper`](https://arxiv.org/abs/1502.03167)
+[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
+and [`the original batch normalization paper`](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.
Normalizes an array across batch and spatial dimensions.
<b> `BatchNormTraining(operand, scale, offset, epsilon, feature_index)` </b>
-| Arguments | Type | Semantics |
-| --------------- | ----------------------- | -------------------------------- |
-| `operand` | `ComputationDataHandle` | n dimensional array to be |
-: : : normalized (x) :
-| `scale` | `ComputationDataHandle` | 1 dimensional array |
-: : : (\\(\gamma\\)) :
-| `offset` | `ComputationDataHandle` | 1 dimensional array |
-: : : (\\(\beta\\)) :
-| `epsilon` | `float` | Epsilon value (\\(\epsilon\\)) |
-| `feature_index` | `int64` | Index to feature dimension |
-: : : in `operand` :
+Arguments | Type | Semantics
+--------------- | ------- | ----------------------------------------
+`operand` | `XlaOp` | n dimensional array to be normalized (x)
+`scale` | `XlaOp` | 1 dimensional array (\\(\gamma\\))
+`offset` | `XlaOp` | 1 dimensional array (\\(\beta\\))
+`epsilon` | `float` | Epsilon value (\\(\epsilon\\))
+`feature_index` | `int64` | Index to feature dimension in `operand`
For each feature in the feature dimension (`feature_index` is the index for the
feature dimension in `operand`), the operation calculates the mean and variance
@@ -158,14 +152,14 @@ contains `m` elements with `w` and `h` as the size of spatial dimensions
The epsilon value, usually a small number, is added to avoid divide-by-zero errors.
-The output type is a tuple of three `ComputationDataHandle`s:
+The output type is a tuple of three `XlaOp`s:
| Outputs | Type | Semantics |
| ------------ | ----------------------- | -------------------------------------|
-| `output` | `ComputationDataHandle` | n dimensional array with the same |
+| `output` | `XlaOp` | n dimensional array with the same |
: : : shape as input `operand` (y) :
-| `batch_mean` | `ComputationDataHandle` | 1 dimensional array (\\(\mu\\)) |
-| `batch_var` | `ComputationDataHandle` | 1 dimensional array (\\(\sigma^2\\)) |
+| `batch_mean` | `XlaOp` | 1 dimensional array (\\(\mu\\)) |
+| `batch_var` | `XlaOp` | 1 dimensional array (\\(\sigma^2\\)) |
The `batch_mean` and `batch_var` are moments calculated across the batch and
spatial dimensions using the formulas above.
@@ -173,7 +167,7 @@ spatial dimensions using the formulas above.
## BitcastConvertType
See also
-[`ComputationBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Similar to a `tf.bitcast` in TensorFlow, performs an element-wise bitcast
operation from a data shape to a target shape. The dimensions must match, and
@@ -183,10 +177,10 @@ with different floating-point representations will give different results.
<b> `BitcastConvertType(operand, new_element_type)` </b>
-Arguments | Type | Semantics
------------------- | ----------------------- | ---------------------------
-`operand` | `ComputationDataHandle` | array of type T with dims D
-`new_element_type` | `PrimitiveType` | type U
+Arguments | Type | Semantics
+------------------ | --------------- | ---------------------------
+`operand` | `XlaOp` | array of type T with dims D
+`new_element_type` | `PrimitiveType` | type U
The dimensions of the operand and the target shape must match. The bit-width of
the source and destination element types must be equal. The source
@@ -195,16 +189,16 @@ and destination element types must not be tuples.
## Broadcast
See also
-[`ComputationBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Adds dimensions to an array by duplicating the data in the array.
<b> `Broadcast(operand, broadcast_sizes)` </b>
-Arguments | Type | Semantics
------------------ | ----------------------- | -------------------------------
-`operand` | `ComputationDataHandle` | The array to duplicate
-`broadcast_sizes` | `ArraySlice<int64>` | The sizes of the new dimensions
+Arguments | Type | Semantics
+----------------- | ------------------- | -------------------------------
+`operand` | `XlaOp` | The array to duplicate
+`broadcast_sizes` | `ArraySlice<int64>` | The sizes of the new dimensions
The new dimensions are inserted on the left, i.e. if `broadcast_sizes` has
values `{a0, ..., aN}` and the operand shape has dimensions `{b0, ..., bM}` then
@@ -223,19 +217,18 @@ For example, if `operand` is a scalar `f32` with value `2.0f`, and
## Call
See also
-[`ComputationBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Invokes a computation with the given arguments.
<b> `Call(computation, args...)` </b>
-| Arguments | Type | Semantics |
-| ------------- | ------------------------ | -------------------------------- |
-| `computation` | `Computation` | computation of type `T_0, T_1, |
-: : : ..., T_N -> S` with N parameters :
-: : : of arbitrary type :
-| `args` | sequence of N | N arguments of arbitrary type |
-: : `ComputationDataHandle`s : :
+| Arguments | Type | Semantics |
+| ------------- | ---------------------- | ----------------------------------- |
+| `computation` | `XlaComputation` | computation of type `T_0, T_1, ..., |
+: : : T_N -> S` with N parameters of :
+: : : arbitrary type :
+| `args` | sequence of N `XlaOp`s | N arguments of arbitrary type |
The arity and types of the `args` must match the parameters of the
`computation`. It is allowed to have no `args`.
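
A minimal sketch of `Call`: build a one-parameter `XlaComputation` with a separate builder, then invoke it. This assumes the usual `Parameter`/`ConstantR0`/`Build` helpers, as in the `ReduceWindow` example later in this document:

```c++
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Build a scalar computation x -> x + 1 and call it with one argument.
void BuildCallExample() {
  xla::XlaComputation add_one;
  {
    xla::XlaBuilder sub("add_one");
    auto p = sub.Parameter(0, xla::ShapeUtil::MakeShape(xla::F32, {}), "p");
    sub.Add(p, sub.ConstantR0<float>(1.0f));
    add_one = sub.Build().ConsumeValueOrDie();
  }
  xla::XlaBuilder builder("caller");
  auto x = builder.ConstantR0<float>(41.0f);
  builder.Call(add_one, /*args=*/{x});
}
```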
@@ -243,17 +236,17 @@ The arity and types of the `args` must match the parameters of the
## Clamp
See also
-[`ComputationBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Clamps an operand to within the range between a minimum and maximum value.
<b> `Clamp(min, operand, max)` </b>
-| Arguments | Type | Semantics |
-| ------------- | ----------------------- | -------------------------------- |
-| `min` | `ComputationDataHandle` | array of type T |
-| `operand` | `ComputationDataHandle` | array of type T |
-| `max` | `ComputationDataHandle` | array of type T |
+Arguments | Type | Semantics
+--------- | ------- | ---------------
+`min` | `XlaOp` | array of type T
+`operand` | `XlaOp` | array of type T
+`max` | `XlaOp` | array of type T
Given an operand and minimum and maximum values, returns the operand if it is in
the range between the minimum and maximum, else returns the minimum value if the
@@ -276,18 +269,17 @@ Clamp(min, operand, max) = s32[3]{0, 5, 6};
## Collapse
See also
-[`ComputationBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h)
+[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
and the @{tf.reshape} operation.
Collapses dimensions of an array into one dimension.
<b> `Collapse(operand, dimensions)` </b>
-| Arguments | Type | Semantics |
-| ------------ | ----------------------- | ----------------------------------- |
-| `operand` | `ComputationDataHandle` | array of type T |
-| `dimensions` | `int64` vector | in-order, consecutive subset of T's |
-: : : dimensions. :
+Arguments | Type | Semantics
+------------ | -------------- | -----------------------------------------------
+`operand` | `XlaOp` | array of type T
+`dimensions` | `int64` vector | in-order, consecutive subset of T's dimensions.
Collapse replaces the given subset of the operand's dimensions by a single
dimension. The input arguments are an arbitrary array of type T and a
@@ -340,7 +332,7 @@ then v12 == f32[8x3] {{10, 11, 12},
## Concatenate
See also
-[`ComputationBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Concatenate composes an array from multiple array operands. The array is of the
same rank as each of the input array operands (which must be of the same rank as
@@ -348,13 +340,13 @@ each other) and contains the arguments in the order that they were specified.
<b> `Concatenate(operands..., dimension)` </b>
-| Arguments | Type | Semantics |
-| ----------- | ----------------------- | ------------------------------------ |
-| `operands` | sequence of N | N arrays of type T with dimensions |
-: : `ComputationDataHandle` : [L0, L1, ...]. Requires N >= 1. :
-| `dimension` | `int64` | A value in the interval `[0, N)` |
-: : : that names the dimension to be :
-: : : concatenated between the `operands`. :
+| Arguments | Type | Semantics |
+| ----------- | --------------------- | -------------------------------------- |
+| `operands` | sequence of N `XlaOp` | N arrays of type T with dimensions |
+: : : [L0, L1, ...]. Requires N >= 1. :
+| `dimension` | `int64` | A value in the interval `[0, N)` that |
+: : : names the dimension to be concatenated :
+: : : between the `operands`. :
With the exception of `dimension` all dimensions must be the same. This is
because XLA does not support "ragged" arrays. Also note that rank-0 values
@@ -395,20 +387,19 @@ Diagram:
## Conditional
-See also [`ComputationBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+See also
+[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
<b> `Conditional(pred, true_operand, true_computation, false_operand,
- false_computation)` </b>
-
-| Arguments | Type | Semantics |
-| ------------------- | ----------------------- | --------------------------- |
-| `pred` | `ComputationDataHandle` | Scalar of type `PRED` |
-| `true_operand` | `ComputationDataHandle` | Argument of type `T_0` |
-| `true_computation` | `Computation` | Computation of type `T_0 -> |
-: : : S` :
-| `false_operand` | `ComputationDataHandle` | Argument of type `T_1` |
-| `false_computation` | `Computation` | Computation of type `T_1 -> |
-: : : S` :
+false_computation)` </b>
+
+Arguments | Type | Semantics
+------------------- | ---------------- | ---------------------------------
+`pred` | `XlaOp` | Scalar of type `PRED`
+`true_operand` | `XlaOp` | Argument of type `T_0`
+`true_computation` | `XlaComputation` | XlaComputation of type `T_0 -> S`
+`false_operand` | `XlaOp` | Argument of type `T_1`
+`false_computation` | `XlaComputation` | XlaComputation of type `T_1 -> S`
Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
is `false`, and returns the result.
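
A minimal sketch, under the same assumptions about the builder helpers: the predicate selects between a negation computation and an identity computation, each built with its own `XlaBuilder`:

```c++
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Build a scalar F32 -> F32 computation: negation or identity.
xla::XlaComputation BuildScalarComputation(const char* name, bool negate) {
  xla::XlaBuilder b(name);
  auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
  if (negate) b.Neg(x);
  return b.Build().ConsumeValueOrDie();
}

void BuildConditionalExample() {
  xla::XlaBuilder builder("conditional");
  auto pred = builder.ConstantR0<bool>(true);
  auto operand = builder.ConstantR0<float>(2.0f);
  auto neg = BuildScalarComputation("neg", /*negate=*/true);
  auto ident = BuildScalarComputation("ident", /*negate=*/false);
  // Returns -2.0f here, since pred is true and true_computation negates.
  builder.Conditional(pred, operand, neg, operand, ident);
}
```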
@@ -425,7 +416,7 @@ executed depending on the value of `pred`.
## Conv (convolution)
See also
-[`ComputationBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
As ConvWithGeneralPadding, but the padding is specified in a short-hand way as
either SAME or VALID. SAME padding pads the input (`lhs`) with zeroes so that
@@ -435,7 +426,7 @@ account. VALID padding simply means no padding.
## ConvWithGeneralPadding (convolution)
See also
-[`ComputationBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Computes a convolution of the kind used in neural networks. Here, a convolution
can be thought of as a n-dimensional window moving across a n-dimensional base
@@ -443,8 +434,8 @@ area and a computation is performed for each possible position of the window.
| Arguments | Type | Semantics |
| ---------------- | ----------------------- | ----------------------------- |
-| `lhs` | `ComputationDataHandle` | rank n+2 array of inputs |
-| `rhs` | `ComputationDataHandle` | rank n+2 array of kernel |
+| `lhs` | `XlaOp` | rank n+2 array of inputs |
+| `rhs` | `XlaOp` | rank n+2 array of kernel |
: : : weights :
| `window_strides` | `ArraySlice<int64>` | n-d array of kernel strides |
| `padding` | `ArraySlice<pair<int64, | n-d array of (low, high) |
@@ -547,7 +538,7 @@ for (b, oz, oy, ox) { // output coordinates
## ConvertElementType
See also
-[`ComputationBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Similar to an element-wise `static_cast` in C++, performs an element-wise
conversion operation from a data shape to a target shape. The dimensions must
@@ -556,10 +547,10 @@ match, and the conversion is an element-wise one; e.g. `s32` elements become
<b> `ConvertElementType(operand, new_element_type)` </b>
-Arguments | Type | Semantics
------------------- | ----------------------- | ---------------------------
-`operand` | `ComputationDataHandle` | array of type T with dims D
-`new_element_type` | `PrimitiveType` | type U
+Arguments | Type | Semantics
+------------------ | --------------- | ---------------------------
+`operand` | `XlaOp` | array of type T with dims D
+`new_element_type` | `PrimitiveType` | type U
The dimensions of the operand and the target shape must match. The source and
destination element types must not be tuples.
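
A minimal sketch of the `s32` to `f32` conversion discussed in this section, assuming `ConstantR1` and `ConvertElementType` keep their usual builder signatures:

```c++
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

// Element-wise cast of an s32 constant vector to f32.
void BuildConvertExample() {
  xla::XlaBuilder builder("convert");
  auto a = builder.ConstantR1<int32_t>({0, 1, 2});  // s32[3] {0, 1, 2}
  builder.ConvertElementType(a, xla::F32);          // f32[3] {0.0, 1.0, 2.0}
}
```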
@@ -581,15 +572,15 @@ then b == f32[3]{0.0, 1.0, 2.0}
## CrossReplicaSum
See also
-[`ComputationBuilder::CrossReplicaSum`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::CrossReplicaSum`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Computes a sum across replicas.
<b> `CrossReplicaSum(operand)` </b>
-| Arguments | Type | Semantics |
-| ------------ | ----------------------- | ---------------------------------- |
-| `operand` | `ComputationDataHandle` | Array to sum across replicas. |
+Arguments | Type | Semantics
+--------- | ------- | -----------------------------
+`operand` | `XlaOp` | Array to sum across replicas.
The output shape is the same as the input shape. For example, if there are two
replicas and the operand has the value `(1.0, 2.5)` and `(3.0, 5.25)`
@@ -607,21 +598,21 @@ than another.
## CustomCall
See also
-[`ComputationBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Call a user-provided function within a computation.
<b> `CustomCall(target_name, args..., shape)` </b>
-| Arguments | Type | Semantics |
-| ------------- | ------------------------ | -------------------------------- |
-| `target_name` | `string` | Name of the function. A call |
-: : : instruction will be emitted :
-: : : which targets this symbol name. :
-| `args` | sequence of N | N arguments of arbitrary type, |
-: : `ComputationDataHandle`s : which will be passed to the :
-: : : function. :
-| `shape` | `Shape` | Output shape of the function |
+| Arguments | Type | Semantics |
+| ------------- | ---------------------- | --------------------------------- |
+| `target_name` | `string` | Name of the function. A call |
+: : : instruction will be emitted which :
+: : : targets this symbol name. :
+| `args` | sequence of N `XlaOp`s | N arguments of arbitrary type, |
+: : : which will be passed to the :
+: : : function. :
+| `shape` | `Shape` | Output shape of the function |
The function signature is the same, regardless of the arity or type of args:
@@ -668,14 +659,14 @@ idempotent.
## Dot
See also
-[`ComputationBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
<b> `Dot(lhs, rhs)` </b>
-Arguments | Type | Semantics
---------- | ----------------------- | ---------------
-`lhs` | `ComputationDataHandle` | array of type T
-`rhs` | `ComputationDataHandle` | array of type T
+Arguments | Type | Semantics
+--------- | ------- | ---------------
+`lhs` | `XlaOp` | array of type T
+`rhs` | `XlaOp` | array of type T
The exact semantics of this operation depend on the ranks of the operands:
@@ -697,15 +688,15 @@ multiplications or matrix/matrix multiplications.
## DotGeneral
See also
-[`ComputationBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
<b> `DotGeneral(lhs, rhs, dimension_numbers)` </b>
-| Arguments | Type | Semantics
-| --------- | ----------------------- | ---------------
-| `lhs` | `ComputationDataHandle` | array of type T
-| `rhs` | `ComputationDataHandle` | array of type T
-| `dimension_numbers` | `DotDimensionNumbers` | array of type T
+Arguments | Type | Semantics
+------------------- | --------------------- | ---------------
+`lhs` | `XlaOp` | array of type T
+`rhs` | `XlaOp` | array of type T
+`dimension_numbers` | `DotDimensionNumbers` | array of type T
As Dot, but allows contracting and batch dimension numbers to be specified for
both the 'lhs' and 'rhs'.
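
A minimal sketch of a batched matrix multiply expressed through `DotDimensionNumbers`; this is an assumption about typical usage, not an excerpt from the XLA sources:

```c++
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Batched matmul: dimension 0 is the batch dimension on both sides;
// lhs dimension 2 contracts against rhs dimension 1, giving f32[8,2,4].
void BuildDotGeneralExample() {
  xla::XlaBuilder builder("dot_general");
  auto lhs = builder.Parameter(
      0, xla::ShapeUtil::MakeShape(xla::F32, {8, 2, 3}), "lhs");
  auto rhs = builder.Parameter(
      1, xla::ShapeUtil::MakeShape(xla::F32, {8, 3, 4}), "rhs");
  xla::DotDimensionNumbers dnums;
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(1);
  builder.DotGeneral(lhs, rhs, dnums);
}
```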
@@ -784,7 +775,7 @@ non-contracting/non-batch dimension.
## DynamicSlice
See also
-[`ComputationBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
DynamicSlice extracts a sub-array from the input array at dynamic
`start_indices`. The size of the slice in each dimension is passed in
@@ -796,22 +787,21 @@ calculation of 'start_indices') is currently implementation-defined.
<b> `DynamicSlice(operand, start_indices, size_indices)` </b>
-| Arguments | Type | Semantics |
-| --------------- | ----------------------- | -------------------------------- |
-| `operand` | `ComputationDataHandle` | N dimensional array of type T |
-| `start_indices` | `ComputationDataHandle` | Rank 1 array of N integers |
-: : : containing the starting indices :
-: : : of the slice for each dimension. :
-: : : Value must be greater than or :
-: : : equal to zero. :
-| `size_indices` | `ArraySlice<int64>` | List of N integers containing |
-: : : the slice size for each :
-: : : dimension. Each value must be :
-: : : strictly greater than zero, and :
-: : : start + size must be less than :
-: : : or equal to the size of the :
-: : : dimension to avoid wrapping :
-: : : modulo dimension size. :
+| Arguments | Type | Semantics |
+| --------------- | ------------------- | ----------------------------------- |
+| `operand` | `XlaOp` | N dimensional array of type T |
+| `start_indices` | `XlaOp` | Rank 1 array of N integers |
+: : : containing the starting indices of :
+: : : the slice for each dimension. Value :
+: : : must be greater than or equal to :
+: : : zero. :
+| `size_indices` | `ArraySlice<int64>` | List of N integers containing the |
+: : : slice size for each dimension. Each :
+: : : value must be strictly greater than :
+: : : zero, and start + size must be less :
+: : : than or equal to the size of the :
+: : : dimension to avoid wrapping modulo :
+: : : dimension size. :
1-dimensional example:
@@ -840,7 +830,7 @@ DynamicSlice(b, s, {2, 2}) produces:
## DynamicUpdateSlice
See also
-[`ComputationBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
DynamicUpdateSlice generates a result which is the value of the input array
`operand`, with a slice `update` overwritten at `start_indices`.
@@ -853,23 +843,19 @@ calculation of 'start_indices') is currently implementation-defined.
<b> `DynamicUpdateSlice(operand, update, start_indices)` </b>
-| Arguments | Type | Semantics |
-| --------------- | ----------------------- | -------------------------------- |
-| `operand` | `ComputationDataHandle` | N dimensional array of type T |
-| `update` | `ComputationDataHandle` | N dimensional array of type T |
-: : : containing the slice update. :
-: : : Each dimension of update shape :
-: : : must be strictly greater than :
-: : : zero, and start + update must be :
-: : : less than or equal to the operand:
-: : : size for each dimension to avoid :
-: : : generating out-of-bounds update :
-: : : indices. :
-| `start_indices` | `ComputationDataHandle` | Rank 1 array of N integers |
-: : : containing the starting indices :
-: : : of the slice for each dimension. :
-: : : Value must be greater than or :
-: : : equal to zero. :
+| Arguments | Type | Semantics |
+| --------------- | ------- | ------------------------------------------------ |
+| `operand` | `XlaOp` | N dimensional array of type T |
+| `update` | `XlaOp` | N dimensional array of type T containing the |
+: : : slice update. Each dimension of update shape :
+: : : must be strictly greater than zero, and start + :
+: : : update must be less than or equal to the operand :
+: : : size for each dimension to avoid generating :
+: : : out-of-bounds update indices. :
+| `start_indices` | `XlaOp` | Rank 1 array of N integers containing the |
+: : : starting indices of the slice for each :
+: : : dimension. Value must be greater than or equal :
+: : : to zero. :
1-dimensional example:
@@ -907,7 +893,7 @@ DynamicUpdateSlice(b, u, s) produces:
## Element-wise binary arithmetic operations
See also
-[`ComputationBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
A set of element-wise binary arithmetic operations is supported.
@@ -917,10 +903,10 @@ Where `Op` is one of `Add` (addition), `Sub` (subtraction), `Mul`
(multiplication), `Div` (division), `Rem` (remainder), `Max` (maximum), `Min`
(minimum), `LogicalAnd` (logical AND), or `LogicalOr` (logical OR).
-Arguments | Type | Semantics
---------- | ----------------------- | ----------------------------------------
-`lhs` | `ComputationDataHandle` | left-hand-side operand: array of type T
-`rhs` | `ComputationDataHandle` | right-hand-side operand: array of type T
+Arguments | Type | Semantics
+--------- | ------- | ----------------------------------------
+`lhs` | `XlaOp` | left-hand-side operand: array of type T
+`rhs` | `XlaOp` | right-hand-side operand: array of type T
The arguments' shapes have to be either similar or compatible. See the
@{$broadcasting$broadcasting} documentation about what it means for shapes to
@@ -952,7 +938,7 @@ shapes of both operands. The semantics are described in detail on the
## Element-wise comparison operations
See also
-[`ComputationBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
A set of standard element-wise binary comparison operations is supported. Note
that standard IEEE 754 floating-point comparison semantics apply when comparing
@@ -964,10 +950,10 @@ Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
(greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt`
(less-than).
-Arguments | Type | Semantics
---------- | ----------------------- | ----------------------------------------
-`lhs` | `ComputationDataHandle` | left-hand-side operand: array of type T
-`rhs` | `ComputationDataHandle` | right-hand-side operand: array of type T
+Arguments | Type | Semantics
+--------- | ------- | ----------------------------------------
+`lhs` | `XlaOp` | left-hand-side operand: array of type T
+`rhs` | `XlaOp` | right-hand-side operand: array of type T
The arguments' shapes have to be either similar or compatible. See the
@{$broadcasting$broadcasting} documentation about what it means for shapes to
@@ -991,7 +977,7 @@ in detail on the @{$broadcasting$broadcasting page}.
## Element-wise unary functions
-ComputationBuilder supports these element-wise unary functions:
+XlaBuilder supports these element-wise unary functions:
<b>`Abs(operand)`</b> Element-wise abs `x -> |x|`.
@@ -1023,9 +1009,9 @@ using the comparison operator of the element type of `operand`.
<b>`Tanh(operand)`</b> Element-wise hyperbolic tangent `x -> tanh(x)`.
-Arguments | Type | Semantics
---------- | ----------------------- | ---------------------------
-`operand` | `ComputationDataHandle` | The operand to the function
+Arguments | Type | Semantics
+--------- | ------- | ---------------------------
+`operand` | `XlaOp` | The operand to the function
The function is applied to each element in the `operand` array, resulting in an
array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
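
A minimal sketch applying two of the unary functions listed above to the same operand, under the same assumptions about the builder API:

```c++
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Element-wise unary functions: each result has the operand's 2x3 shape.
void BuildUnaryExample() {
  xla::XlaBuilder builder("unary");
  auto x = builder.Parameter(
      0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), "x");
  builder.Abs(x);   // x -> |x|
  builder.Tanh(x);  // x -> tanh(x)
}
```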
@@ -1038,16 +1024,16 @@ potentially different runtime offset) of an input tensor into an output tensor.
### General Semantics
See also
-[`ComputationBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
For a more intuitive description, see the "Informal Description" section below.
<b> `gather(operand, gather_indices, output_window_dims, elided_window_dims, window_bounds, gather_dims_to_operand_dims)` </b>
|Arguments | Type | Semantics |
|----------------- | ----------------------- | --------------------------------|
-|`operand` | `ComputationDataHandle` | The tensor we’re gathering |
+|`operand` | `XlaOp` | The tensor we’re gathering |
: : : from. :
-|`gather_indices` | `ComputationDataHandle` | Tensor containing the starting |
+|`gather_indices` | `XlaOp` | Tensor containing the starting |
: : : indices of the slices we're :
: : : stitching together into the :
: : : output tensor. :
@@ -1241,7 +1227,7 @@ concatenation of all these rows.
## GetTupleElement
See also
-[`ComputationBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Indexes into a tuple with a compile-time-constant value.
@@ -1262,7 +1248,7 @@ See also @{tf.tuple}.
## Infeed
See also
-[`ComputationBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
<b> `Infeed(shape)` </b>
@@ -1275,7 +1261,7 @@ See also
Reads a single data item from the implicit Infeed streaming interface of the
device, interpreting the data as the given shape and its layout, and returns a
-`ComputationDataHandle` of the data. Multiple Infeed operations are allowed in a
+`XlaOp` of the data. Multiple Infeed operations are allowed in a
computation, but there must be a total order among the Infeed operations. For
example, two Infeeds in the code below have a total order since there is a
dependency between the while loops.
@@ -1301,21 +1287,19 @@ Infeed of the device.
## Map
See also
-[`ComputationBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
<b> `Map(operands..., computation)` </b>
-| Arguments | Type | Semantics |
-| ----------------- | ------------------------ | ----------------------------- |
-| `operands` | sequence of N | N arrays of types T_0..T_{N-1}|
-: : `ComputationDataHandle`s : :
-| `computation` | `Computation` | computation of type `T_0, |
-: : : T_1, ..., T_{N + M -1} -> S` :
-: : : with N parameters of type T :
-: : : and M of arbitrary type :
-| `dimensions` | `int64` array | array of map dimensions |
-| `static_operands` | sequence of M | M arrays of arbitrary type |
-: : `ComputationDataHandle`s : :
+| Arguments | Type | Semantics |
+| ----------------- | ---------------------- | ------------------------------ |
+| `operands` | sequence of N `XlaOp`s | N arrays of types T_0..T_{N-1} |
+| `computation` | `XlaComputation` | computation of type `T_0, T_1, |
+: : : ..., T_{N + M -1} -> S` with N :
+: : : parameters of type T and M of :
+: : : arbitrary type :
+| `dimensions` | `int64` array | array of map dimensions |
+| `static_operands` | sequence of M `XlaOp`s | M arrays of arbitrary type |
Applies a scalar function over the given `operands` arrays, producing an array
of the same dimensions where each element is the result of the mapped function
@@ -1334,18 +1318,18 @@ input arrays to produce the output array.
## Pad
See also
-[`ComputationBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
<b> `Pad(operand, padding_value, padding_config)` </b>
-| Arguments | Type | Semantics |
-| ---------------- | ----------------------- | ----------------------------- |
-| `operand` | `ComputationDataHandle` | array of type `T` |
-| `padding_value` | `ComputationDataHandle` | scalar of type `T` to fill in |
-: : : the added padding :
-| `padding_config` | `PaddingConfig` | padding amount on both edges |
-: : : (low, high) and between the :
-: : : elements of each dimension :
+| Arguments | Type | Semantics |
+| ---------------- | --------------- | --------------------------------------- |
+| `operand` | `XlaOp` | array of type `T` |
+| `padding_value` | `XlaOp` | scalar of type `T` to fill in the added |
+: : : padding :
+| `padding_config` | `PaddingConfig` | padding amount on both edges (low, |
+: : : high) and between the elements of each :
+: : : dimension :
Expands the given `operand` array by padding around the array as well as between
the elements of the array with the given `padding_value`. `padding_config`
@@ -1373,7 +1357,7 @@ are all 0. The figure below shows examples of different `edge_padding` and
## Recv
See also
-[`ComputationBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
<b> `Recv(shape, channel_handle)` </b>
@@ -1384,7 +1368,7 @@ See also
Receives data of the given shape from a `Send` instruction in another
computation that shares the same channel handle. Returns a
-ComputationDataHandle for the received data.
+XlaOp for the received data.
The client API of `Recv` operation represents synchronous communication.
However, the instruction is internally decomposed into 2 HLO instructions
@@ -1407,19 +1391,18 @@ complete and returns the received data.
## Reduce
See also
-[`ComputationBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Applies a reduction function to an array.
<b> `Reduce(operand, init_value, computation, dimensions)` </b>
-| Arguments | Type | Semantics |
-| ------------- | ----------------------- | -------------------------------- |
-| `operand` | `ComputationDataHandle` | array of type `T` |
-| `init_value` | `ComputationDataHandle` | scalar of type `T` |
-| `computation` | `Computation` | computation of type `T, T -> T` |
-| `dimensions` | `int64` array | unordered array of dimensions to |
-: : : reduce :
+Arguments | Type | Semantics
+------------- | ---------------- | ---------------------------------------
+`operand` | `XlaOp` | array of type `T`
+`init_value` | `XlaOp` | scalar of type `T`
+`computation` | `XlaComputation` | computation of type `T, T -> T`
+`dimensions` | `int64` array | unordered array of dimensions to reduce
This operation reduces one or more dimensions of the input array into scalars.
The rank of the returned array is `rank(operand) - len(dimensions)`.
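
A minimal sketch of a sum reduction, mirroring the way the `ReduceWindow` example later in this document builds its reduction computation; the `Build`/`ConsumeValueOrDie` plumbing is assumed, not quoted:

```c++
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sum-reduce a [2, 3] array over dimension 0, yielding a [3] array.
void BuildReduceExample() {
  xla::XlaComputation add;
  {
    xla::XlaBuilder sub("add");
    auto a = sub.Parameter(0, xla::ShapeUtil::MakeShape(xla::F32, {}), "a");
    auto b = sub.Parameter(1, xla::ShapeUtil::MakeShape(xla::F32, {}), "b");
    sub.Add(a, b);
    add = sub.Build().ConsumeValueOrDie();
  }
  xla::XlaBuilder builder("reduce");
  auto input = builder.Parameter(
      0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), "input");
  auto zero = builder.ConstantR0<float>(0.0f);
  builder.Reduce(input, zero, add, /*dimensions_to_reduce=*/{0});
}
```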
@@ -1525,7 +1508,7 @@ Reducing the 3D array over all its dimensions produces the scalar `84`.
## ReducePrecision
See also
-[`ComputationBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Models the effect of converting floating-point values to a lower-precision
format (such as IEEE-FP16) and back to the original format. The number of
@@ -1535,14 +1518,11 @@ implementations.
<b> `ReducePrecision(operand, mantissa_bits, exponent_bits)` </b>
-| Arguments | Type | Semantics |
-| ------------------- | ----------------------- | ---------------------------- |
-| `operand` | `ComputationDataHandle` | array of floating-point type |
-: : : `T`. :
-| `exponent_bits` | `int32` | number of exponent bits in |
-: : : lower-precision format :
-| `mantissa_bits` | `int32` | number of mantissa bits in |
-: : : lower-precision format :
+Arguments | Type | Semantics
+--------------- | ------- | -------------------------------------------------
+`operand` | `XlaOp` | array of floating-point type `T`.
+`exponent_bits` | `int32` | number of exponent bits in lower-precision format
+`mantissa_bits` | `int32` | number of mantissa bits in lower-precision format
The result is an array of type `T`. The input values are rounded to the nearest
value representable with the given number of mantissa bits (using "ties to even"
@@ -1559,7 +1539,7 @@ portion of the conversion is then simply a no-op.
## ReduceWindow
See also
-[`ComputationBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Applies a reduction function to all elements in each window of the input
multi-dimensional array, producing an output multi-dimensional array with the
@@ -1571,25 +1551,25 @@ on the left-hand side.
<b> `ReduceWindow(operand, init_value, computation, window_dimensions,
window_strides, padding)` </b>
-| Arguments | Type | Semantics |
-| ------------------- | ----------------------- | ---------------------------- |
-| `operand` | `ComputationDataHandle` | N dimensional array |
-: : : containing elements of type :
-: : : T. This is the base area on :
-: : : which the window is placed. :
-| `init_value` | `ComputationDataHandle` | Starting value for the |
-: : : reduction. See [Reduce] :
-: : : (#reduce) for details. :
-| `computation` | `Computation` | Reduction function of type |
-: : : `T, T -> T`, to apply to all :
-: : : elements in each window :
-| `window_dimensions` | `ArraySlice<int64>` | array of integers for window |
-: : : dimension values :
-| `window_strides` | `ArraySlice<int64>` | array of integers for window |
-: : : stride values :
-| `padding` | `Padding` | padding type for window |
-: : : (Padding\:\:kSame or :
-: : : Padding\:\:kValid) :
+| Arguments | Type | Semantics |
+| ------------------- | ------------------- | -------------------------------- |
+| `operand` | `XlaOp` | N dimensional array containing |
+: : : elements of type T. This is the :
+: : : base area on which the window is :
+: : : placed. :
+| `init_value` | `XlaOp` | Starting value for the |
+: : : reduction. See [Reduce](#reduce) :
+: : : for details. :
+| `computation` | `XlaComputation` | Reduction function of type `T, T |
+: : : -> T`, to apply to all elements :
+: : : in each window :
+| `window_dimensions` | `ArraySlice<int64>` | array of integers for window |
+: : : dimension values :
+| `window_strides` | `ArraySlice<int64>` | array of integers for window |
+: : : stride values :
+| `padding` | `Padding` | padding type for window |
+: : : (Padding\:\:kSame or :
+: : : Padding\:\:kValid) :
Below code and figure shows an example of using `ReduceWindow`. Input is a
matrix of size [4x6] and both window_dimensions and window_stride_dimensions are
@@ -1597,9 +1577,9 @@ matrix of size [4x6] and both window_dimensions and window_stride_dimensions are
```
// Create a computation for the reduction (maximum).
-Computation max;
+XlaComputation max;
{
- ComputationBuilder builder(client_, "max");
+ XlaBuilder builder(client_, "max");
auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
builder.Max(y, x);
@@ -1607,7 +1587,7 @@ Computation max;
}
// Create a ReduceWindow computation with the max reduction computation.
-ComputationBuilder builder(client_, "reduce_window_2x3");
+XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
@@ -1642,7 +1622,7 @@ context of [`Reduce`](#reduce) for more details.
## Reshape
See also
-[`ComputationBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h)
+[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h)
and the [`Collapse`](#collapse) operation.
Reshapes the dimensions of an array into a new configuration.
@@ -1650,11 +1630,11 @@ Reshapes the dimensions of an array into a new configuration.
<b> `Reshape(operand, new_sizes)` </b>
<b> `Reshape(operand, dimensions, new_sizes)` </b>
-Arguments | Type | Semantics
------------- | ----------------------- | ---------------------------------------
-`operand` | `ComputationDataHandle` | array of type T
-`dimensions` | `int64` vector | order in which dimensions are collapsed
-`new_sizes` | `int64` vector | vector of sizes of new dimensions
+Arguments | Type | Semantics
+------------ | -------------- | ---------------------------------------
+`operand` | `XlaOp` | array of type T
+`dimensions` | `int64` vector | order in which dimensions are collapsed
+`new_sizes` | `int64` vector | vector of sizes of new dimensions
Conceptually, reshape first flattens an array into a one-dimensional vector of
data values, and then refines this vector into a new shape. The input arguments
@@ -1723,14 +1703,14 @@ Reshape(5, {}, {1,1}) == f32[1x1] {{5}};
## Rev (reverse)
See also
-[`ComputationBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
<b>`Rev(operand, dimensions)`</b>
-Arguments | Type | Semantics
------------- | ----------------------- | ---------------------
-`operand` | `ComputationDataHandle` | array of type T
-`dimensions` | `ArraySlice<int64>` | dimensions to reverse
+Arguments | Type | Semantics
+------------ | ------------------- | ---------------------
+`operand` | `XlaOp` | array of type T
+`dimensions` | `ArraySlice<int64>` | dimensions to reverse
Reverses the order of elements in the `operand` array along the specified
`dimensions`, generating an output array of the same shape. Each element of the
@@ -1745,7 +1725,7 @@ the two window dimensions during the gradient computation in neural networks.
## RngNormal
See also
-[`ComputationBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Constructs an output of a given shape with random numbers generated following
the $$N(\mu, \sigma)$$ normal distribution. The parameters `mu` and `sigma`, and
@@ -1754,18 +1734,18 @@ be scalar valued.
<b>`RngNormal(mean, sigma, shape)`</b>
-| Arguments | Type | Semantics |
-| --------- | ----------------------- | -------------------------------------- |
-| `mu` | `ComputationDataHandle` | Scalar of type F32 specifying mean of |
-: : : generated numbers :
-| `sigma` | `ComputationDataHandle` | Scalar of type F32 specifying standard |
-: : : deviation of generated numbers :
-| `shape` | `Shape` | Output shape of type F32 |
+| Arguments | Type | Semantics |
+| --------- | ------- | --------------------------------------------------- |
+| `mu` | `XlaOp` | Scalar of type F32 specifying mean of generated |
+: : : numbers :
+| `sigma` | `XlaOp` | Scalar of type F32 specifying standard deviation of |
+: : : generated numbers :
+| `shape` | `Shape` | Output shape of type F32 |
## RngUniform
See also
-[`ComputationBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Constructs an output of a given shape with random numbers generated following
the uniform distribution over the interval $$[a,b)$$. The parameters and output
@@ -1777,27 +1757,27 @@ is implementation-defined.
| Arguments | Type | Semantics |
| --------- | ----------------------- | --------------------------------- |
-| `a` | `ComputationDataHandle` | Scalar of type T specifying lower |
+| `a` | `XlaOp` | Scalar of type T specifying lower |
: : : limit of interval :
-| `b` | `ComputationDataHandle` | Scalar of type T specifying upper |
+| `b` | `XlaOp` | Scalar of type T specifying upper |
: : : limit of interval :
| `shape` | `Shape` | Output shape of type T |
## Select
See also
-[`ComputationBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Constructs an output array from elements of two input arrays, based on the
values of a predicate array.
<b> `Select(pred, on_true, on_false)` </b>
-Arguments | Type | Semantics
----------- | ----------------------- | ------------------
-`pred` | `ComputationDataHandle` | array of type PRED
-`on_true` | `ComputationDataHandle` | array of type T
-`on_false` | `ComputationDataHandle` | array of type T
+Arguments | Type | Semantics
+---------- | ------- | ------------------
+`pred` | `XlaOp` | array of type PRED
+`on_true` | `XlaOp` | array of type T
+`on_false` | `XlaOp` | array of type T
The arrays `on_true` and `on_false` must have the same shape. This is also the
shape of the output array. The array `pred` must have the same dimensionality as
@@ -1837,7 +1817,7 @@ the same shape!) then `pred` has to be a scalar of type `PRED`.
## SelectAndScatter
See also
-[`ComputationBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
This operation can be considered as a composite operation that first computes
`ReduceWindow` on the `operand` array to select an element from each window, and
@@ -1870,33 +1850,32 @@ backpropagate the gradient values for a pooling layer in a neural network.
<b>`SelectAndScatter(operand, select, window_dimensions, window_strides,
padding, source, init_value, scatter)`</b>
-| Arguments | Type | Semantics |
-| ------------------- | ----------------------- | ---------------------------- |
-| `operand` | `ComputationDataHandle` | array of type T over which |
-: : : the windows slide :
-| `select` | `Computation` | binary computation of type |
-: : : `T, T -> PRED`, to apply to :
-: : : all elements in each window; :
-: : : returns `true` if the first :
-: : : parameter is selected and :
-: : : returns `false` if the :
-: : : second parameter is selected :
-| `window_dimensions` | `ArraySlice<int64>` | array of integers for window |
-: : : dimension values :
-| `window_strides` | `ArraySlice<int64>` | array of integers for window |
-: : : stride values :
-| `padding` | `Padding` | padding type for window |
-: : : (Padding\:\:kSame or :
-: : : Padding\:\:kValid) :
-| `source` | `ComputationDataHandle` | array of type T with the |
-: : : values to scatter :
-| `init_value` | `ComputationDataHandle` | scalar value of type T for |
-: : : the initial value of the :
-: : : output array :
-| `scatter` | `Computation` | binary computation of type |
-: : : `T, T -> T`, to apply each :
-: : : scatter source element with :
-: : : its destination element :
+| Arguments | Type | Semantics |
+| ------------------- | ------------------- | -------------------------------- |
+| `operand` | `XlaOp` | array of type T over which the |
+: : : windows slide :
+| `select` | `XlaComputation` | binary computation of type `T, T |
+: : : -> PRED`, to apply to all :
+: : : elements in each window; returns :
+: : : `true` if the first parameter is :
+: : : selected and returns `false` if :
+: : : the second parameter is selected :
+| `window_dimensions` | `ArraySlice<int64>` | array of integers for window |
+: : : dimension values :
+| `window_strides` | `ArraySlice<int64>` | array of integers for window |
+: : : stride values :
+| `padding` | `Padding` | padding type for window |
+: : : (Padding\:\:kSame or :
+: : : Padding\:\:kValid) :
+| `source` | `XlaOp` | array of type T with the values |
+: : : to scatter :
+| `init_value` | `XlaOp` | scalar value of type T for the |
+: : : initial value of the output :
+: : : array :
+| `scatter` | `XlaComputation` | binary computation of type `T, T |
+: : : -> T`, to apply each scatter :
+: : : source element with its :
+: : : destination element :
The figure below shows examples of using `SelectAndScatter`, with the `select`
function computing the maximal value among its parameters. Note that when the
@@ -1918,14 +1897,14 @@ context of [`Reduce`](#reduce) for more details.
## Send
See also
-[`ComputationBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
<b> `Send(operand, channel_handle)` </b>
-| Arguments | Type | Semantics |
-| ---------------- | ----------------------- | -------------------------------- |
-| `operand` | `ComputationDataHandle` | data to send (array of type T) |
-| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |
+Arguments | Type | Semantics
+---------------- | --------------- | -----------------------------------------
+`operand` | `XlaOp` | data to send (array of type T)
+`channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair
Sends the given operand data to a `Recv` instruction in another computation
that shares the same channel handle. Does not return any data.
@@ -1973,7 +1952,7 @@ computations. For example, below schedules lead to deadlocks.
## Slice
See also
-[`ComputationBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Slicing extracts a sub-array from the input array. The sub-array is of the same
rank as the input and contains the values inside a bounding box within the input
@@ -1982,23 +1961,20 @@ arguments to the slice operation.
<b> `Slice(operand, start_indices, limit_indices)` </b>
-| Arguments | Type | Semantics |
-| --------------- | ----------------------- | -------------------------------- |
-| `operand` | `ComputationDataHandle` | N dimensional array of type T |
-| `start_indices` | `ArraySlice<int64>` | List of N integers containing |
-: : : the starting indices of the :
-: : : slice for each dimension. Values :
-: : : must be greater than or equal to :
-: : : zero. :
-| `limit_indices` | `ArraySlice<int64>` | List of N integers containing |
-: : : the ending indices (exclusive) :
-: : : for the slice for each :
-: : : dimension. Each value must be :
-: : : strictly greater than the :
-: : : respective `start_indices` value :
-: : : for the dimension and less than :
-: : : or equal to the size of the :
-: : : dimension. :
+| Arguments | Type | Semantics |
+| --------------- | ------------------- | ------------------------------------ |
+| `operand` | `XlaOp` | N dimensional array of type T |
+| `start_indices` | `ArraySlice<int64>` | List of N integers containing the |
+: : : starting indices of the slice for :
+: : : each dimension. Values must be :
+: : : greater than or equal to zero. :
+| `limit_indices` | `ArraySlice<int64>` | List of N integers containing the |
+: : : ending indices (exclusive) for the :
+: : : slice for each dimension. Each value :
+: : : must be strictly greater than the :
+: : : respective `start_indices` value for :
+: : : the dimension and less than or equal :
+: : : to the size of the dimension. :
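As a rough NumPy analogue of these semantics (illustration only, not the XLA builder API), each dimension keeps the indices in `[start_indices[d], limit_indices[d])`:

```python
# NumPy sketch of Slice(operand, start_indices, limit_indices):
# dimension d keeps indices start_indices[d] .. limit_indices[d] - 1.
import numpy as np

operand = np.arange(12).reshape(3, 4)
start_indices = (1, 0)
limit_indices = (3, 2)
sliced = operand[tuple(slice(s, l)
                       for s, l in zip(start_indices, limit_indices))]
print(sliced.shape)   # (2, 2)
```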
1-dimensional example:
@@ -2025,15 +2001,15 @@ Slice(b, {2, 1}, {4, 3}) produces:
## Sort
See also
-[`ComputationBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
Sorts the elements in the operand.
<b>`Sort(operand)`</b>
-Arguments | Type | Semantics
---------- | ----------------------- | -------------------
-`operand` | `ComputationDataHandle` | The operand to sort
+Arguments | Type | Semantics
+--------- | ------- | -------------------
+`operand` | `XlaOp` | The operand to sort
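For a rank-1 operand the behavior matches an ordinary element sort; a NumPy illustration (not the XLA API):

```python
# Illustration only: sorting the elements of a rank-1 operand.
import numpy as np

operand = np.array([3, 1, 2])
print(np.sort(operand))   # [1 2 3]
```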
## Transpose
@@ -2041,10 +2017,10 @@ See also the @{tf.reshape} operation.
<b>`Transpose(operand)`</b>
-Arguments | Type | Semantics
---------- | ----------------------- | -------------------------
-`operand` | `ComputationDataHandle` | The operand to transpose.
-`permutation` | `ArraySlice<int64>` | How to permute the dimensions.
+Arguments | Type | Semantics
+------------- | ------------------- | ------------------------------
+`operand` | `XlaOp` | The operand to transpose.
+`permutation` | `ArraySlice<int64>` | How to permute the dimensions.
Permutes the operand dimensions with the given permutation, so
@@ -2056,7 +2032,7 @@ This is the same as Reshape(operand, permutation,
## Tuple
See also
-[`ComputationBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
A tuple containing a variable number of data handles, each of which has its own
shape.
@@ -2075,18 +2051,19 @@ Tuples can be deconstructed (accessed) via the [`GetTupleElement`]
## While
See also
-[`ComputationBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
<b> `While(condition, body, init)` </b>
-| Arguments | Type | Semantics |
-| ----------- | ------------- | ---------------------------------------------- |
-| `condition` | `Computation` | Computation of type `T -> PRED` which defines |
-: : : the termination condition of the loop. :
-| `body` | `Computation` | Computation of type `T -> T` which defines the |
-: : : body of the loop. :
-| `init` | `T` | Initial value for the parameter of `condition` |
-: : : and `body`. :
+| Arguments | Type | Semantics |
+| ----------- | ---------------- | ---------------------------------------- |
+| `condition` | `XlaComputation` | XlaComputation of type `T -> PRED` which |
+: : : defines the termination condition of the :
+: : : loop. :
+| `body` | `XlaComputation` | XlaComputation of type `T -> T` which |
+: : : defines the body of the loop. :
+| `init` | `T` | Initial value for the parameter of |
+: : : `condition` and `body`. :
Sequentially executes the `body` until the `condition` fails. This is similar to
a typical while loop in many other languages except for the differences and
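A minimal Python sketch of the `While(condition, body, init)` semantics described above (illustration only; the loop-carried value plays the role of `T`):

```python
# Sketch: the value of type T is threaded through condition and body until
# condition returns false, and the final value is the result of the loop.
def while_loop(condition, body, init):
    state = init
    while condition(state):
        state = body(state)
    return state

print(while_loop(lambda x: x < 10, lambda x: x + 1, 0))   # 10
```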
diff --git a/tensorflow/docs_src/programmers_guide/eager.md b/tensorflow/docs_src/programmers_guide/eager.md
index 5926e9f7f4..9719858e88 100644
--- a/tensorflow/docs_src/programmers_guide/eager.md
+++ b/tensorflow/docs_src/programmers_guide/eager.md
@@ -120,11 +120,11 @@ def fizzbuzz(max_num):
counter = tf.constant(0)
for num in range(max_num):
num = tf.constant(num)
- if num % 3 == 0 and num % 5 == 0:
+ if int(num % 3) == 0 and int(num % 5) == 0:
print('FizzBuzz')
- elif num % 3 == 0:
+ elif int(num % 3) == 0:
print('Fizz')
- elif num % 5 == 0:
+ elif int(num % 5) == 0:
print('Buzz')
else:
print(num)
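The added `int(...)` casts make the modulo comparisons operate on plain Python integers; presumably a scalar `tf.Tensor` does not behave like a Python int in these `==` checks, so the example converts first. A minimal sketch of the pattern (assumes eager execution is enabled):

```python
# Convert the scalar tensor to a Python int before comparing against 0.
import tensorflow as tf
tf.enable_eager_execution()

num = tf.constant(6)
if int(num % 3) == 0:      # int(...) yields a plain Python integer
  print('divisible by 3')
```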
diff --git a/tensorflow/docs_src/tutorials/audio_recognition.md b/tensorflow/docs_src/tutorials/audio_recognition.md
index 372ab47df7..d7a8da6f96 100644
--- a/tensorflow/docs_src/tutorials/audio_recognition.md
+++ b/tensorflow/docs_src/tutorials/audio_recognition.md
@@ -25,13 +25,15 @@ python tensorflow/examples/speech_commands/train.py
```
The script will start off by downloading the [Speech Commands
-dataset](https://storage.cloud.google.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz),
-which consists of 65,000 WAVE audio files of people saying thirty different
-words. This data was collected by Google and released under a CC BY license, and
-you can help improve it by [contributing five minutes of your own
+dataset](https://storage.cloud.google.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz),
+which consists of over 105,000 WAVE audio files of people saying thirty
+different words. This data was collected by Google and released under a CC BY
+license, and you can help improve it by [contributing five minutes of your own
voice](https://aiyprojects.withgoogle.com/open_speech_recording). The archive is
-over 1GB, so this part may take a while, but you should see progress logs, and
-once it's been downloaded once you won't need to do this step again.
+over 2GB, so this part may take a while, but you should see progress logs, and
+once it's been downloaded once you won't need to do this step again. You can
+find more information about this dataset in this
+[Speech Commands paper](https://arxiv.org/abs/1804.03209).
Once the downloading has completed, you'll see logging information that looks
like this:
@@ -229,7 +231,7 @@ You can also build this application yourself, since it's open source and
[available as part of the TensorFlow repository on
github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#building-in-android-studio-using-the-tensorflow-aar-from-jcenter).
By default it downloads [a pretrained model from
-tensorflow.org](http://download.tensorflow.org/models/speech_commands_v0.01.zip),
+tensorflow.org](http://download.tensorflow.org/models/speech_commands_v0.02.zip),
but you can easily [replace it with a model you've trained
yourself](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-model-files-optional).
If you do this, you'll need to make sure that the constants in [the main
diff --git a/tensorflow/examples/learn/text_classification_cnn.py b/tensorflow/examples/learn/text_classification_cnn.py
index 9e21aee87f..a40a9eaecb 100644
--- a/tensorflow/examples/learn/text_classification_cnn.py
+++ b/tensorflow/examples/learn/text_classification_cnn.py
@@ -73,7 +73,7 @@ def cnn_model(features, labels, mode):
kernel_size=FILTER_SHAPE2,
padding='VALID')
# Max across each filter to get useful features for classification.
- pool2 = tf.squeeze(tf.reduce_max(conv2, 1), squeeze_dims=[1])
+ pool2 = tf.squeeze(tf.reduce_max(conv2, 1), axis=[1])
# Apply regular WX + B and classification.
logits = tf.layers.dense(pool2, MAX_LABEL, activation=None)
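The hunk above swaps the deprecated `squeeze_dims` keyword of `tf.squeeze` for `axis`; a minimal sketch of the equivalent call:

```python
# `axis` replaces the deprecated `squeeze_dims` keyword of tf.squeeze.
import tensorflow as tf

t = tf.zeros([2, 1, 3])
squeezed = tf.squeeze(t, axis=[1])   # resulting shape: [2, 3]
```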
diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py
index f084931215..fc28eb0631 100644
--- a/tensorflow/examples/speech_commands/train.py
+++ b/tensorflow/examples/speech_commands/train.py
@@ -288,7 +288,7 @@ if __name__ == '__main__':
'--data_url',
type=str,
# pylint: disable=line-too-long
- default='http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz',
+ default='http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz',
# pylint: enable=line-too-long
help='Location of speech training data archive on the web.')
parser.add_argument(
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 70a271bd2e..36db3dda6b 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -2610,6 +2610,70 @@ func Reverse(scope *Scope, tensor tf.Output, dims tf.Output) (output tf.Output)
return op.Output(0)
}
+// Copy a tensor setting everything outside a central band in each innermost matrix
+//
+// to zero.
+//
+// The `band` part is computed as follows:
+// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
+// tensor with the same shape where
+//
+// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
+//
+// The indicator function
+//
+// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower) &&
+//                  (num_upper < 0 || (n-m) <= num_upper)`.
+//
+// For example:
+//
+// ```
+// # if 'input' is [[ 0, 1, 2, 3]
+// [-1, 0, 1, 2]
+// [-2, -1, 0, 1]
+// [-3, -2, -1, 0]],
+//
+// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3]
+// [-1, 0, 1, 2]
+// [ 0, -1, 0, 1]
+// [ 0, 0, -1, 0]],
+//
+// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0]
+// [-1, 0, 1, 0]
+// [-2, -1, 0, 1]
+// [ 0, -2, -1, 0]]
+// ```
+//
+// Useful special cases:
+//
+// ```
+// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.
+// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.
+// tf.matrix_band_part(input, 0, 0) ==> Diagonal.
+// ```
+//
+// Arguments:
+// input: Rank `k` tensor.
+// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire
+// lower triangle.
+// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep
+// entire upper triangle.
+//
+// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor.
+func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "MatrixBandPart",
+ Input: []tf.Input{
+ input, num_lower, num_upper,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Clips tensor values to a specified min and max.
//
// Given a tensor `t`, this operation returns a tensor of the same type and
@@ -7246,70 +7310,6 @@ func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...
return op.Output(0)
}
-// Copy a tensor setting everything outside a central band in each innermost matrix
-//
-// to zero.
-//
-// The `band` part is computed as follows:
-// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
-// tensor with the same shape where
-//
-// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
-//
-// The indicator function
-//
-// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) &&
-// (num_upper < 0 || (n-m) <= num_upper)`.
-//
-// For example:
-//
-// ```
-// # if 'input' is [[ 0, 1, 2, 3]
-// [-1, 0, 1, 2]
-// [-2, -1, 0, 1]
-// [-3, -2, -1, 0]],
-//
-// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3]
-// [-1, 0, 1, 2]
-// [ 0, -1, 0, 1]
-// [ 0, 0, -1, 0]],
-//
-// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0]
-// [-1, 0, 1, 0]
-// [-2, -1, 0, 1]
-// [ 0, -2, -1, 0]]
-// ```
-//
-// Useful special cases:
-//
-// ```
-// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.
-// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.
-// tf.matrix_band_part(input, 0, 0) ==> Diagonal.
-// ```
-//
-// Arguments:
-// input: Rank `k` tensor.
-// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire
-// lower triangle.
-// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep
-// entire upper triangle.
-//
-// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor.
-func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "MatrixBandPart",
- Input: []tf.Input{
- input, num_lower, num_upper,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter.
type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index a865e8ca75..d804578070 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -502,7 +502,10 @@ py_test(
cc_library(
name = "python_op_gen",
- srcs = ["framework/python_op_gen.cc"],
+ srcs = [
+ "framework/python_op_gen.cc",
+ "framework/python_op_gen_internal.cc",
+ ],
hdrs = [
"framework/python_op_gen.h",
"framework/python_op_gen_internal.h",
@@ -524,12 +527,12 @@ cc_library(
srcs = ["framework/python_op_gen_main.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":python_op_gen",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/python/eager:python_eager_op_gen",
],
)
@@ -624,6 +627,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":pywrap_tensorflow",
+ "//tensorflow/core:protos_all_py",
],
)
@@ -3033,9 +3037,12 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":array_ops",
+ ":constant_op",
+ ":control_flow_ops",
":dtypes",
":io_ops_gen",
":ops",
+ ":saveable_object",
":util",
"//tensorflow/python/eager:context",
],
@@ -3221,6 +3228,18 @@ py_test(
)
py_test(
+ name = "util_serialization_test",
+ size = "small",
+ srcs = ["util/serialization_test.py"],
+ main = "util/serialization_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_testlib",
+ ":util",
+ ],
+)
+
+py_test(
name = "future_api_test",
size = "small",
srcs = ["util/future_api_test.py"],
@@ -3232,6 +3251,16 @@ py_test(
)
py_test(
+ name = "function_utils_test",
+ srcs = ["util/function_utils_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_testlib",
+ ":util",
+ ],
+)
+
+py_test(
name = "tf_contextlib_test",
size = "small",
srcs = ["util/tf_contextlib_test.py"],
@@ -3526,7 +3555,6 @@ tf_py_wrap_cc(
"//tensorflow/core/profiler/internal:print_model_analysis",
"//tensorflow/tools/graph_transforms:transform_graph_lib",
"//tensorflow/python/eager:pywrap_tfe_lib",
- "//tensorflow/python/eager:python_eager_op_gen",
"//util/python:python_headers",
] + (tf_additional_lib_deps() +
tf_additional_plugin_deps() +
@@ -3941,7 +3969,18 @@ cuda_py_test(
":math_ops",
"//tensorflow/core:protos_all_py",
],
- tags = ["noguitar"],
+)
+
+py_test(
+ name = "c_api_util_test",
+ size = "small",
+ srcs = ["framework/c_api_util_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":c_api_util",
+ ":framework_test_lib",
+ ":platform_test",
+ ],
)
py_test(
@@ -4217,7 +4256,7 @@ tf_py_test(
py_test(
name = "basic_session_run_hooks_test",
- size = "small",
+ size = "medium",
srcs = ["training/basic_session_run_hooks_test.py"],
srcs_version = "PY2AND3",
tags = [
diff --git a/tensorflow/python/client/virtual_gpu_test.py b/tensorflow/python/client/virtual_gpu_test.py
index addf63474c..ae653e03dd 100644
--- a/tensorflow/python/client/virtual_gpu_test.py
+++ b/tensorflow/python/client/virtual_gpu_test.py
@@ -236,7 +236,7 @@ class VirtualGpuTest(test_util.TensorFlowTestCase):
with self.test_session(config=self._util.config) as sess:
if not test.is_gpu_available(cuda_only=True):
self.skipTest('No GPU available')
- for _ in range(10):
+ for _ in range(5):
if not self._util.TestRandomGraph(sess):
return
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
index 6aabad2f57..9fcdf1b062 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
@@ -357,6 +357,65 @@ class DatasetConstructorTest(test.TestCase):
# iterator terminates (and the generator iterator is deleted).
self.assertTrue(event.is_set())
+ def testFromGeneratorWithArgs(self):
+
+ def flat_map_fn(elem):
+
+ def generator_with_arg(n):
+ for _ in range(n):
+ yield np.array(n, dtype=np.int64)
+
+ return dataset_ops.Dataset.from_generator(
+ generator_with_arg, output_types=dtypes.int64, output_shapes=(),
+ args=(elem,))
+
+ iterator = (dataset_ops.Dataset
+ .range(5)
+ .flat_map(flat_map_fn)
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
+ for x in expected:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testFromGeneratorWithTwoArgs(self):
+
+ def flat_map_fn(elem, message):
+
+ def generator_with_arg(n, msg):
+ for i in range(n):
+ yield i, msg
+
+ return dataset_ops.Dataset.from_generator(
+ generator_with_arg, output_types=(dtypes.int64, dtypes.string),
+ output_shapes=((), ()), args=(elem, message))
+
+ iterator = (
+ dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.range(5),
+ dataset_ops.Dataset.from_tensors("Hi!").repeat(None)))
+ .flat_map(flat_map_fn)
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ expected = [(0, b"Hi!"),
+ (0, b"Hi!"), (1, b"Hi!"),
+ (0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"),
+ (0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"), (3, b"Hi!")]
+ for x in expected:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
def testGeneratorDatasetFinalizeFunctionCalled(self):
# NOTE(mrry): This test tests the internal `_GeneratorDataset`,
# which affords more control over what the finalize function can do than
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index bd9686f692..8b3c2facbc 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
import abc
-import collections
import threading
import numpy as np
@@ -259,25 +258,32 @@ class Dataset(object):
self._generator = generator
self._lock = threading.Lock()
self._next_id = 0 # GUARDED_BY(self._lock)
- self._iterators = collections.defaultdict(lambda: iter(generator()))
+ self._args = {}
+ self._iterators = {}
- def get_next_id(self):
+ def get_next_id(self, *args):
with self._lock:
ret = self._next_id
self._next_id += 1
+ self._args[ret] = args
# NOTE(mrry): Explicitly create an array of `np.int64` because implicit
# casting in `py_func()` will create an array of `np.int32` on Windows,
# leading to a runtime error.
return np.array(ret, dtype=np.int64)
def get_iterator(self, iterator_id):
- return self._iterators[iterator_id]
+ try:
+ return self._iterators[iterator_id]
+ except KeyError:
+ iterator = iter(self._generator(*self._args.pop(iterator_id)))
+ self._iterators[iterator_id] = iterator
+ return iterator
def iterator_completed(self, iterator_id):
del self._iterators[iterator_id]
@staticmethod
- def from_generator(generator, output_types, output_shapes=None):
+ def from_generator(generator, output_types, output_shapes=None, args=None):
"""Creates a `Dataset` whose elements are generated by `generator`.
The `generator` argument must be a callable object that returns
@@ -320,13 +326,17 @@ class Dataset(object):
`Dataset.from_generator()`.
Args:
- generator: A callable object that takes no arguments and returns an
- object that supports the `iter()` protocol.
+ generator: A callable object that returns an object that supports the
+ `iter()` protocol. If `args` is not specified, `generator` must take
+ no arguments; otherwise it must take as many arguments as there are
+ values in `args`.
output_types: A nested structure of `tf.DType` objects corresponding to
each component of an element yielded by `generator`.
output_shapes: (Optional.) A nested structure of `tf.TensorShape`
objects corresponding to each component of an element yielded by
`generator`.
+ args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
+ and passed to `generator` as NumPy-array arguments.
Returns:
Dataset: A `Dataset`.
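A usage sketch of the new `args` parameter, modeled on the `testFromGeneratorWithArgs` test added earlier in this change (the generator and dataset names here are illustrative):

```python
# Each value in `args` is evaluated and passed to `generator` as a NumPy
# array, so the generator can be parameterized per flat_map element.
import numpy as np
import tensorflow as tf

def generator_with_arg(n):
  for _ in range(n):
    yield np.array(n, dtype=np.int64)

dataset = tf.data.Dataset.range(5).flat_map(
    lambda elem: tf.data.Dataset.from_generator(
        generator_with_arg, output_types=tf.int64, output_shapes=(),
        args=(elem,)))
```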
@@ -339,6 +349,10 @@ class Dataset(object):
else:
output_shapes = nest.map_structure_up_to(
output_types, tensor_shape.as_shape, output_shapes)
+ if args is None:
+ args = ()
+ else:
+ args = tuple(ops.convert_n_to_tensor(args, name="args"))
flattened_types = nest.flatten(output_types)
flattened_shapes = nest.flatten(output_shapes)
@@ -359,7 +373,7 @@ class Dataset(object):
`generator_state`.
"""
return script_ops.py_func(
- generator_state.get_next_id, [], dtypes.int64, stateful=True)
+ generator_state.get_next_id, args, dtypes.int64, stateful=True)
def generator_next_fn(iterator_id_t):
"""Generates the next element from iterator with ID `iterator_id_t`.
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 0c76afd29d..fd164277b6 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -52,6 +52,9 @@ GET_NEXT_CALL_WARNING_MESSAGE = (
"`next_element` as the input to some computation that is invoked inside "
"the loop.")
+# Collection of all IteratorResources in the `Graph`.
+GLOBAL_ITERATORS = "iterators"
+
@tf_export("data.Iterator")
class Iterator(object):
@@ -75,8 +78,7 @@ class Iterator(object):
output_shapes: A nested structure of `tf.TensorShape` objects
corresponding to each component of an element of this dataset.
output_classes: A nested structure of Python `type` object corresponding
- to each
- component of an element of this iterator.
+ to each component of an element of this iterator.
"""
self._iterator_resource = iterator_resource
self._initializer = initializer
@@ -86,6 +88,7 @@ class Iterator(object):
self._string_handle = gen_dataset_ops.iterator_to_string_handle(
self._iterator_resource)
self._get_next_call_count = 0
+ ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)
@staticmethod
def from_structure(output_types,
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index b3268c9047..5530193d4e 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -25,6 +25,7 @@ cc_library(
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tape",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/python:ndarray_tensor",
"//tensorflow/python:ndarray_tensor_bridge",
"//tensorflow/python:numpy_lib",
@@ -191,22 +192,6 @@ py_library(
],
)
-cc_library(
- name = "python_eager_op_gen",
- srcs = ["python_eager_op_gen.cc"],
- hdrs = ["python_eager_op_gen.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:op_gen_lib",
- "//tensorflow/core:proto_text",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/python:python_op_gen",
- ],
-)
-
py_library(
name = "graph_only_ops",
srcs = ["graph_only_ops.py"],
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index be674487f1..73dbbedbe9 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -96,6 +96,18 @@ class BackpropTest(test.TestCase):
self.assertAllEqual(grads_and_vars[0][0], 1.0)
self.assertAllEqual(id(grads_and_vars[0][1]), id(x))
+ def testWhereGradient(self):
+ # Note: where is special because only some of its arguments are of
+ # differentiable dtypes.
+
+ def f(x):
+ return array_ops.where(x < 10, x, x * x)
+
+ g = backprop.gradients_function(f)
+
+ self.assertAllEqual(g(5.)[0], 1.0)
+ self.assertAllEqual(g(50.)[0], 100.0)
+
def testTwoTargets(self):
with backprop.GradientTape() as t:
x = constant_op.constant(3.0)
diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc
deleted file mode 100644
index 9afab0077b..0000000000
--- a/tensorflow/python/eager/python_eager_op_gen.cc
+++ /dev/null
@@ -1,1047 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/python/eager/python_eager_op_gen.h"
-
-#include <stdio.h>
-#include <sstream>
-#include <unordered_map>
-#include "tensorflow/core/framework/api_def.pb.h"
-#include "tensorflow/core/framework/attr_value.pb.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_def.pb_text.h"
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/framework/op_def_util.h"
-#include "tensorflow/core/framework/op_gen_lib.h"
-#include "tensorflow/core/framework/tensor.pb_text.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/gtl/stl_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/python/framework/python_op_gen_internal.h"
-
-namespace tensorflow {
-namespace {
-
-const int kRightMargin = 78;
-
-constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
-
-string AttrVarName(const string& attr_name,
- std::unordered_map<string, string>* attr_expressions) {
- const string var = strings::StrCat("_attr_", attr_name);
- if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
- return var;
-}
-
-void AddInferredAttr(const string& indentation, const string& attr_name,
- const string& value_expression, string* result,
- std::unordered_map<string, string>* attr_expressions) {
- strings::StrAppend(result, indentation,
- AttrVarName(attr_name, attr_expressions), " = ",
- value_expression, "\n");
-}
-
-string VectorToTuple(const std::vector<string>& l) {
- if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
- string ret = "(";
- for (int i = 0; i < l.size(); ++i) {
- if (i > 0) {
- strings::StrAppend(&ret, ", ");
- }
- strings::StrAppend(&ret, l[i]);
- }
- strings::StrAppend(&ret, ")");
- return ret;
-}
-
-void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
- const string& var, string* result) {
- for (int i = 0; i < output_sizes.size(); ++i) {
- if (!output_sizes[i].empty()) {
- strings::StrAppend(result, prefix, var, " = ");
- if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
- if (i + 1 < output_sizes.size()) {
- // Special case i == 0 to avoid "0 +" in the generated code.
- if (i == 0) {
- strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
- var, "[", output_sizes[i], ":]");
- } else {
- strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
- output_sizes[i], "]] + ", var, "[", i, " + ",
- output_sizes[i], ":]");
- }
- } else {
- strings::StrAppend(result, "[", var, "[", i, ":]]");
- }
- strings::StrAppend(result, "\n");
- }
- }
-}
-
-string TensorPBString(const TensorProto& pb) {
- // Note: This gets used in the argument list, and so must survive naive
- // word wrapping.
- return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
-}
-
-const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
- for (int i = 0; i < api_def.in_arg_size(); ++i) {
- if (api_def.in_arg(i).name() == name) {
- return &api_def.in_arg(i);
- }
- }
- return nullptr;
-}
-
-class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
- public:
- GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
- const string& function_name)
- : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) {
- op_name_ = function_name_;
- str_util::ConsumePrefix(&op_name_, "_");
- }
- ~GenEagerPythonOp() override {}
-
- string Code() override;
-
- protected:
- void HandleGraphMode(const string& function_setup);
-
- string GetEagerNotAllowedError();
- void ExpectListArg(const string& indentation, const string& arg_name,
- string* output);
- bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
- void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
- string* num_outputs_expr);
-
- void AddEagerFunctionTeardown(const string& indentation,
- const std::vector<string>& output_sizes,
- bool execute_record_gradient);
-
- bool AddEagerFastPathAndGraphCode(const string& parameters,
- const std::vector<string>& output_sizes,
- const string& eager_not_allowed_error);
- bool AddEagerFallbackCode(const string& parameters,
- const std::vector<string>& output_sizes,
- const string& num_outputs_expr,
- const string& eager_not_allowed_error);
- void AddEagerFastPathExecute();
-
- void AddEagerInferredAttrs(const string& indentation);
- void AddEagerInputCasts(const string& indentation);
- void AddEagerAttrs(const string& indentation);
- void AddEagerExecute(const string& indentation,
- const string& num_outputs_expr);
-
- void AddAttrForArg(const string& attr, int arg_index) {
- gtl::InsertIfNotPresent(&inferred_attrs_, attr,
- op_def_.input_arg(arg_index).name());
- auto iter = attr_to_args_.find(attr);
- if (iter == attr_to_args_.end()) {
- attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
- } else {
- iter->second.push_back(arg_index);
- }
- }
-
- // Returns a string expression representing a flattened list of all
- // the inputs given by `*input_indices` (or all inputs if
- // `input_indices` is nullptr). `*output_sizes` can be used to unflatten.
- string FlattenInputs(const std::vector<int>* input_indices,
- std::vector<string>* output_sizes) const;
-
- StringPiece op_name_;
- typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
- AttrToArgMap attr_to_args_;
- std::unordered_map<string, string> attr_expressions_;
- // This has all the input args followed by those attrs that don't have
- // defaults.
- std::vector<python_op_gen_internal::ParamNames> params_no_default_;
- // The parameters with defaults (these have to be listed after those without).
- // No input args are included, just attrs.
- std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
- params_with_default_;
-};
-
-string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
- const string& function_name) {
- return GenEagerPythonOp(op_def, api_def, function_name).Code();
-}
-
-string GenEagerPythonOp::FlattenInputs(
- const std::vector<int>* input_indices,
- std::vector<string>* output_sizes) const {
- string inputs;
- enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
- const int n = input_indices != nullptr ? input_indices->size()
- : op_def_.input_arg_size();
- for (int j = 0; j < n; ++j) {
- const int i = input_indices ? (*input_indices)[j] : j;
- const auto& arg(op_def_.input_arg(i));
- const bool is_list =
- !arg.type_list_attr().empty() || !arg.number_attr().empty();
- if (is_list) {
- if (inputs_state == WAS_SOLO_INPUT) {
- strings::StrAppend(&inputs, "] + ");
- } else if (inputs_state == WAS_LIST_INPUT) {
- strings::StrAppend(&inputs, " + ");
- }
- strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
- inputs_state = WAS_LIST_INPUT;
- if (output_sizes != nullptr) {
- if (!arg.number_attr().empty()) {
- output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
- } else {
- output_sizes->emplace_back(
- strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
- }
- }
- } else {
- if (inputs_state == WAS_SOLO_INPUT) {
- strings::StrAppend(&inputs, ", ");
- } else if (inputs_state == WAS_LIST_INPUT) {
- strings::StrAppend(&inputs, " + [");
- } else {
- strings::StrAppend(&inputs, "[");
- }
- strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
- inputs_state = WAS_SOLO_INPUT;
- if (output_sizes != nullptr) output_sizes->emplace_back();
- }
- }
- if (inputs_state == STARTING) return "[]";
- if (inputs_state == WAS_SOLO_INPUT) {
- strings::StrAppend(&inputs, "]");
- }
- return inputs;
-}
-
-string GenEagerPythonOp::Code() {
- if (api_def_.visibility() == ApiDef::SKIP) {
- return "";
- }
-
- for (int i = 0; i < api_def_.arg_order_size(); ++i) {
- const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
- const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
- params_no_default_.emplace_back(api_def_arg.name(),
- api_def_arg.rename_to());
- if (!arg.type_attr().empty()) {
- AddAttrForArg(arg.type_attr(), i);
- } else if (!arg.type_list_attr().empty()) {
- AddAttrForArg(arg.type_list_attr(), i);
- }
- if (!arg.number_attr().empty()) {
- AddAttrForArg(arg.number_attr(), i);
- }
- }
- for (int i = 0; i < op_def_.attr_size(); ++i) {
- const auto& attr(op_def_.attr(i));
- const auto& api_def_attr(api_def_.attr(i));
- // Do not add inferred attrs to the Python function signature.
- if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
- if (api_def_attr.has_default_value()) {
- if (attr.type() == "tensor") {
- params_with_default_.emplace_back(
- python_op_gen_internal::ParamNames(api_def_attr.name(),
- api_def_attr.rename_to()),
- strings::StrCat(
- "_execute.make_tensor(",
- TensorPBString(api_def_attr.default_value().tensor()), ", \"",
- api_def_attr.rename_to(), "\")"));
- } else if (attr.type() == "list(tensor)") {
- std::vector<string> pbtxt;
- for (const auto& pb : api_def_attr.default_value().list().tensor()) {
- pbtxt.emplace_back(TensorPBString(pb));
- }
- params_with_default_.emplace_back(
- python_op_gen_internal::ParamNames(api_def_attr.name(),
- api_def_attr.rename_to()),
- strings::StrCat("[_execute.make_tensor(_pb, \"",
- api_def_attr.rename_to(), "\") for _pb in ",
- VectorToTuple(pbtxt), "]"));
- } else {
- params_with_default_.emplace_back(
- python_op_gen_internal::ParamNames(api_def_attr.name(),
- api_def_attr.rename_to()),
- python_op_gen_internal::AttrValueToPython(
- attr.type(), api_def_attr.default_value(), "_dtypes."));
- }
- } else {
- params_no_default_.emplace_back(api_def_attr.name(),
- api_def_attr.rename_to());
- }
- }
- }
-
- // Save the list of attr parameters (attrs that won't be inferred),
- // those with defaults go at the end.
- // Get the attrs in the order we want by taking the attrs without defaults
- // from the end of params_no_default_, and adding params_no_default_.
- attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
- params_with_default_.size());
- for (int i = op_def_.input_arg_size(); i < params_no_default_.size(); ++i) {
- attrs_.push_back(params_no_default_[i].GetName());
- }
- for (const auto& p : params_with_default_) {
- attrs_.push_back(p.first.GetName());
- }
-
- param_names_.reserve(params_no_default_.size() + params_with_default_.size());
- param_names_.insert(param_names_.begin(), params_no_default_.begin(),
- params_no_default_.end());
- for (const auto& param_and_default : params_with_default_) {
- param_names_.push_back(param_and_default.first);
- }
-
- string parameters;
- for (const auto& param : params_no_default_) {
- if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
- strings::StrAppend(&parameters, param.GetRenameTo());
- }
- for (const auto& param_and_default : params_with_default_) {
- if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
- strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), "=",
- param_and_default.second);
- }
- if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
- strings::StrAppend(&parameters, "name=None");
-
- // Add attr_expressions_ for attrs that are params.
- for (int i = 0; i < attrs_.size(); ++i) {
- const string& attr_name = attrs_[i];
- const string& attr_api_name =
- param_names_[i + op_def_.input_arg_size()].GetRenameTo();
- attr_expressions_[attr_name] = attr_api_name;
- }
- // Add attr_expressions_ for attrs that are inferred.
- for (int i = 0; i < op_def_.attr_size(); ++i) {
- const auto& attr(op_def_.attr(i));
- if (attr.type() == "int") {
- auto arg_list = attr_to_args_.find(attr.name());
- if (arg_list != attr_to_args_.end()) {
- AttrVarName(attr.name(), &attr_expressions_);
- }
- }
- }
-
- string num_outputs_expr;
- std::vector<string> output_sizes(num_outs_);
- GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);
-
- string eager_not_allowed_error = GetEagerNotAllowedError();
-
- if (!AddEagerFastPathAndGraphCode(parameters, output_sizes,
- eager_not_allowed_error)) {
- return result_;
- }
-
- if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
- eager_not_allowed_error)) {
- return result_;
- }
-
- return prelude_ + result_;
-}
-
-void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
- // Handle graph-mode case
- strings::StrAppend(&result_,
- " _ctx = _context._context\n"
- " if _ctx is None or not _ctx._eager_context.is_eager:\n",
- function_setup,
- " _, _, _op = _op_def_lib._apply_op_helper(\n");
- AddBodyNoReturn(" ");
- if (num_outs_ > 0) {
- strings::StrAppend(&result_, " _result = _op.outputs[:]\n");
- // Special case handling for stateful op with single list output
- // that might be empty.
- if (num_outs_ == 1 && op_def_.is_stateful() &&
- (!op_def_.output_arg(0).number_attr().empty() ||
- !op_def_.output_arg(0).type_list_attr().empty())) {
- // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
- // a constraint indicating that this can never be empty.
- strings::StrAppend(&result_,
- " if not _result:\n"
- " return _op\n");
- }
- strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n");
-
- // Compute graph-mode attrs.
- if (op_def_.attr_size() > 0) {
- string attr_values;
- for (int i = 0; i < op_def_.attr_size(); ++i) {
- if (i > 0) strings::StrAppend(&attr_values, ", ");
- const auto& attr_name(op_def_.attr(i).name());
- strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
- attr_name, "\")");
- }
- strings::StrAppend(&attr_values, ")");
- strings::StrAppend(&result_,
- WordWrap(" _attrs = (", attr_values, kRightMargin),
- "\n");
- } else {
- strings::StrAppend(&result_, " _attrs = None\n");
- }
- } else {
- strings::StrAppend(&result_, " return _op\n");
- }
-}
-
-string GenEagerPythonOp::GetEagerNotAllowedError() {
- bool eager_allowed = true;
- string ref_arg;
- for (int i = 0; i < op_def_.input_arg_size(); ++i) {
- const auto& arg = op_def_.input_arg(i);
- if (arg.is_ref()) {
- eager_allowed = false;
- DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
- ref_arg = api_def_.in_arg(i).rename_to();
- }
- }
- for (int i = 0; i < op_def_.output_arg_size(); ++i) {
- const auto& arg = op_def_.output_arg(i);
- if (arg.is_ref()) {
- eager_allowed = false;
- DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
- ref_arg = api_def_.out_arg(i).rename_to();
- }
- }
-
- if (eager_allowed) return "";
-
- return strings::StrCat("raise RuntimeError(\"", op_name_,
- " op does not support eager execution. ", "Arg '",
- ref_arg, "' is a ref.\")\n");
-}
-
-void GenEagerPythonOp::ExpectListArg(const string& indentation,
- const string& arg_name, string* output) {
- strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
- ", (list, tuple)):\n", indentation, " raise TypeError(\n",
- indentation, " \"Expected list for '", arg_name,
- "' argument to \"\n", indentation, " \"'", op_name_,
- "' Op, not %r.\" % ", arg_name, ")\n");
-}
-
-bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
- string* function_setup) {
- // Validate list inputs, infer length attrs.
- for (int i = 0; i < op_def_.attr_size(); ++i) {
- const auto& attr(op_def_.attr(i));
- if (attr.type() == "int") {
- auto arg_list = attr_to_args_.find(attr.name());
- if (arg_list != attr_to_args_.end()) {
- // Inferred int attrs are the lengths of inputs. Validate those
- // inputs are lists and have the same length.
- for (auto iter = arg_list->second.begin();
- iter != arg_list->second.end(); ++iter) {
- const string& arg_api_name = param_names_[*iter].GetRenameTo();
- ExpectListArg(indentation, arg_api_name, function_setup);
- if (iter == arg_list->second.begin()) {
- AddInferredAttr(indentation, attr.name(),
- strings::StrCat("len(", arg_api_name, ")"),
- function_setup, &attr_expressions_);
- } else {
- const auto& attr_var = attr_expressions_[attr.name()];
- strings::StrAppend(
- function_setup, indentation, "if len(", arg_api_name,
- ") != ", attr_var, ":\n", indentation, " raise ValueError(\n",
- indentation, " \"List argument '", arg_api_name, "' to '",
- op_name_, "' Op with length %d \"\n", indentation,
- " \"must match length %d of argument '",
- inferred_attrs_[attr.name()], "'.\" %\n", indentation,
- " (len(", arg_api_name, "), ", attr_var, "))\n");
- }
- }
- }
- }
- }
-
- for (int i = 0; i < attrs_.size(); ++i) {
- const string& attr_name = attrs_[i];
- const auto& param = param_names_[i + op_def_.input_arg_size()];
- const auto& attr = *FindAttr(attr_name, op_def_);
- const string& attr_api_name = param.GetRenameTo();
- StringPiece attr_type = attr.type();
- attr_expressions_[attr_name] = attr_api_name;
- const int default_index = i - (attrs_.size() - params_with_default_.size());
- if (default_index >= 0) {
- const string& default_value = params_with_default_[default_index].second;
- strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
- " is None:\n");
- strings::StrAppend(function_setup, indentation, " ", attr_api_name,
- " = ", default_value, "\n");
- }
- if (str_util::StartsWith(attr_type, "list(")) {
- ExpectListArg(indentation, attr_api_name, function_setup);
- }
-
- if (attr_type == "string") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = _execute.make_str(", attr_api_name, ", \"",
- attr_api_name, "\")\n");
- } else if (attr_type == "list(string)") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = [_execute.make_str(_s, \"", attr_api_name,
- "\") for _s in ", attr_api_name, "]\n");
- } else if (attr_type == "int") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = _execute.make_int(", attr_api_name, ", \"",
- attr_api_name, "\")\n");
- } else if (attr_type == "list(int)") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = [_execute.make_int(_i, \"", attr_api_name,
- "\") for _i in ", attr_api_name, "]\n");
- } else if (attr_type == "float") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = _execute.make_float(", attr_api_name, ", \"",
- attr_api_name, "\")\n");
- } else if (attr_type == "list(float)") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = [_execute.make_float(_f, \"", attr_api_name,
- "\") for _f in ", attr_api_name, "]\n");
- } else if (attr_type == "bool") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = _execute.make_bool(", attr_api_name, ", \"",
- attr_api_name, "\")\n");
- } else if (attr_type == "list(bool)") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = [_execute.make_bool(_b, \"", attr_api_name,
- "\") for _b in ", attr_api_name, "]\n");
- } else if (attr_type == "type") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = _execute.make_type(", attr_api_name, ", \"",
- attr_api_name, "\")\n");
- } else if (attr_type == "list(type)") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = [_execute.make_type(_t, \"", attr_api_name,
- "\") for _t in ", attr_api_name, "]\n");
- } else if (attr_type == "shape") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = _execute.make_shape(", attr_api_name, ", \"",
- attr_api_name, "\")\n");
- } else if (attr_type == "list(shape)") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = [_execute.make_shape(_s, \"", attr_api_name,
- "\") for _s in ", attr_api_name, "]\n");
- } else if (attr_type == "tensor") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = _execute.make_tensor(", attr_api_name, ", \"",
- attr_api_name, "\")\n");
- } else if (attr_type == "list(tensor)") {
- strings::StrAppend(function_setup, indentation, attr_api_name,
- " = [_execute.make_tensor(_t, \"", attr_api_name,
- "\") for _t in ", attr_api_name, "]\n");
- } else if (attr_type != "func") {
- *function_setup =
- strings::StrCat("# No definition for ", function_name_,
- " since we don't support attrs with type\n"
- "# '",
- attr_type, "' right now.\n\n");
- return false;
- }
- }
- return true;
-}
-
-// If output i is list output, output_sizes[i] will be set to a
-// string with the python expression that will evaluate to its
-// length. output_sizes[i] is empty for non-list outputs.
-void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
- std::vector<string>* output_sizes, string* num_outputs_expr) {
- // Expression representing the number of outputs.
- int num_fixed_outputs = 0;
- for (int i = 0; i < num_outs_; ++i) {
- const auto& arg(op_def_.output_arg(i));
- if (!arg.number_attr().empty()) {
- if (!num_outputs_expr->empty()) {
- strings::StrAppend(num_outputs_expr, " + ");
- }
- (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
- strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
- } else if (!arg.type_list_attr().empty()) {
- if (!num_outputs_expr->empty()) {
- strings::StrAppend(num_outputs_expr, " + ");
- }
- // Have to be careful to use an expression that works in both
- // graph and eager paths here.
- const auto iter = inferred_attrs_.find(arg.type_list_attr());
- if (iter == inferred_attrs_.end()) {
- (*output_sizes)[i] = strings::StrCat(
- "len(", attr_expressions_[arg.type_list_attr()], ")");
- } else {
- (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
- }
- strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
- } else {
- ++num_fixed_outputs;
- }
- }
- if (num_fixed_outputs > 0) {
- if (!num_outputs_expr->empty()) {
- strings::StrAppend(num_outputs_expr, " + ");
- }
- strings::StrAppend(num_outputs_expr, num_fixed_outputs);
- } else if (num_outputs_expr->empty()) {
- *num_outputs_expr = "0";
- }
-}
-
-void GenEagerPythonOp::AddEagerFunctionTeardown(
- const string& indentation, const std::vector<string>& output_sizes,
- bool execute_record_gradient) {
- if (num_outs_ > 0) {
- if (execute_record_gradient) {
- strings::StrAppend(&result_, indentation, "_execute.record_gradient(\n",
- " \"", op_def_.name(),
- "\", _inputs_flat, _attrs, _result, name)\n");
- }
- if (num_outs_ == 1 && !output_sizes[0].empty()) {
- // Single list result.
- } else if (num_outs_ == 1) {
- // Execute returns a single-element list which we need to destructure.
- strings::StrAppend(&result_, indentation, "_result, = _result\n");
- } else {
- // Have multiple outputs, so we will need to reformat the return
- // value of execute() to be a list with one entry per op output
- // (that entry will be a list of tensors if that output is of list
- // type).
- // For list outputs, convert the right subrange of _result into a list.
- Unflatten(indentation, output_sizes, "_result", &result_);
- // Convert to a named tuple.
- strings::StrAppend(&result_, indentation, "_result = _", op_def_.name(),
- "Output._make(_result)\n");
- }
- } else {
- strings::StrAppend(&result_, indentation, "_result = None\n");
- }
- strings::StrAppend(&result_, indentation, "return _result\n\n");
-}
-
-bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
- const string& parameters, const std::vector<string>& output_sizes,
- const string& eager_not_allowed_error) {
- AddExport();
- AddDefLine(function_name_, parameters);
- AddDocStringDescription();
- AddDocStringArgs();
- AddDocStringInputs();
- AddDocStringAttrs();
- AddDocStringNameArg();
- AddOutputGlobals(); // Added to prelude_
- AddDocStringOutputs();
- strings::StrAppend(&result_, " \"\"\"\n");
-
- // Handle graph-mode case
- string function_setup;
- if (!GetEagerFunctionSetup(" ", &function_setup)) {
- result_ = function_setup;
- return false;
- }
- HandleGraphMode(function_setup);
- AddEagerFunctionTeardown(" ", output_sizes,
- true /* execute_record_gradient */);
-
- // Handle eager-mode case
- strings::StrAppend(&result_, " else:\n");
-
- if (eager_not_allowed_error.empty()) {
- AddEagerFastPathExecute();
- } else {
- strings::StrAppend(&result_, " ", eager_not_allowed_error);
- }
-
- strings::StrAppend(&result_, "\n\n");
- return true;
-}
-
-bool GenEagerPythonOp::AddEagerFallbackCode(
- const string& parameters, const std::vector<string>& output_sizes,
- const string& num_outputs_expr, const string& eager_not_allowed_error) {
- if (!eager_not_allowed_error.empty()) {
- strings::StrAppend(&result_, " ", eager_not_allowed_error);
- return true;
- }
-
- AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix),
- strings::StrCat(parameters, ", ctx=None"));
- strings::StrAppend(
- &result_, " r\"\"\"This is the slowpath function for Eager mode.\n");
- strings::StrAppend(&result_, " This is for function ", function_name_,
- "\n \"\"\"\n");
-
- strings::StrAppend(&result_, " _ctx = ctx if ctx else _context.context()\n");
-
- string function_setup;
- if (!GetEagerFunctionSetup(" ", &function_setup)) {
- result_ = function_setup;
- return false;
- }
- strings::StrAppend(&result_, function_setup);
-
- AddEagerInferredAttrs(" ");
- AddEagerInputCasts(" ");
- strings::StrAppend(
- &result_, " _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
- AddEagerAttrs(" ");
- AddEagerExecute(" ", num_outputs_expr);
-
- AddEagerFunctionTeardown(" ", output_sizes,
- true /* execute_record_gradient */);
-
- return true;
-}
-
-void GenEagerPythonOp::AddEagerFastPathExecute() {
- string fastpath_execute_params = strings::StrCat(
- "_ctx._context_handle, _ctx._eager_context.device_name, \"",
- op_def_.name(), "\", ", "name, _ctx._post_execution_callbacks");
- string fallback_params;
-
- for (int i = 0; i < api_def_.in_arg_size(); i++) {
- const string param_name = param_names_[i].GetRenameTo();
- strings::StrAppend(&fastpath_execute_params, ", ", param_name);
- if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
- strings::StrAppend(&fallback_params, param_name);
- }
-
- for (const auto& attr : api_def_.attr()) {
- if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
- strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
- attr.rename_to());
-
- if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
- strings::StrAppend(&fallback_params, attr.rename_to(), "=",
- attr.rename_to());
- }
- }
-
- if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
- strings::StrAppend(&fallback_params, "name=name");
-
- strings::StrAppend(&result_, " try:\n");
- strings::StrAppend(
- &result_, " ",
- "_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n",
- WordWrap(strings::StrCat(" "),
- strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
- "\n");
-
- if (op_def_.output_arg_size() > 1) {
- const string output_tuple_name =
- strings::StrCat("_", op_def_.name(), "Output");
- strings::StrAppend(&result_, " ", "_result = ", output_tuple_name,
- "._make(_result)\n");
- }
- strings::StrAppend(&result_, " ", "return _result\n");
-
- // Handle fallback.
- if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
- strings::StrAppend(&fallback_params, "ctx=_ctx");
- strings::StrAppend(&result_, " ", "except _core._FallbackException:\n");
- strings::StrAppend(
- &result_, " ", "return ", function_name_, kEagerFallbackSuffix,
- "(\n",
- WordWrap(strings::StrCat(" "),
- strings::StrCat(fallback_params, ")"), kRightMargin),
- "\n");
-
- // Any errors thrown from execute need to be unwrapped from
- // _NotOkStatusException.
- strings::StrAppend(&result_, " ",
- "except _core._NotOkStatusException as e:\n");
- strings::StrAppend(&result_, " ", "if name is not None:\n");
- strings::StrAppend(&result_, " ",
- "message = e.message + \" name: \" + name\n");
- strings::StrAppend(&result_, " ", "else:\n");
- strings::StrAppend(&result_, " ", "message = e.message\n");
- strings::StrAppend(
- &result_, " ",
- "_six.raise_from(_core._status_to_exception(e.code, message), None)\n");
-}
-
-void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
- // Figure out values for inferred attrs, and cast to eager tensors.
- for (int i = 0; i < op_def_.attr_size(); ++i) {
- const auto& attr(op_def_.attr(i));
- const auto& api_def_attr(api_def_.attr(i));
- auto arg_list = attr_to_args_.find(attr.name());
- if (arg_list != attr_to_args_.end()) {
- if (attr.type() == "type") {
- std::vector<string> output_sizes;
- const string flattened =
- FlattenInputs(&arg_list->second, &output_sizes);
- string conversion = strings::StrCat("_execute.args_to_matching_eager(",
- flattened, ", _ctx");
- if (attr.has_default_value()) {
- strings::StrAppend(
- &conversion, ", ",
- python_op_gen_internal::AttrValueToPython(
- attr.type(), api_def_attr.default_value(), "_dtypes."));
- }
- strings::StrAppend(&conversion, ")");
- const string var_name = AttrVarName(attr.name(), &attr_expressions_);
- if (output_sizes.size() == 1) {
- // Avoid creating a temporary variable in the case where
- // we can easily assign to the right value directly.
- const string inputs_var =
- param_names_[arg_list->second.front()].GetRenameTo();
- if (output_sizes.front().empty()) {
- strings::StrAppend(&result_, indentation, var_name, ", (",
- inputs_var, ",) = ", conversion, "\n");
- } else {
- strings::StrAppend(&result_, indentation, var_name, ", ",
- inputs_var, " = ", conversion, "\n");
- }
- } else {
- const string inputs_var = strings::StrCat("_inputs_", attr.name());
- strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
- " = ", conversion, "\n");
- // Convert from a flat list of eager tensors back to the
- // parameter variables.
- Unflatten(indentation, output_sizes, inputs_var, &result_);
- std::vector<string> p;
- for (int j : arg_list->second) {
- p.emplace_back(param_names_[j].GetRenameTo());
- }
- strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
- inputs_var, "\n");
- }
- } else if (attr.type() == "list(type)") {
- // NOTE: We ignore default values for these attrs, since it is
- // unclear how you would use it, and the one use case is
- // parse_single_sequence_example which only needs it for
- // backwards compatibility.
- const string var_name = AttrVarName(attr.name(), &attr_expressions_);
- string inputs_var;
- string conversion;
- if (arg_list->second.size() > 1) {
- // If you have more than one list(tensor) argument, their types
- // have to match.
- std::vector<string> lists;
- for (auto iter = arg_list->second.begin();
- iter != arg_list->second.end(); ++iter) {
- lists.push_back(param_names_[*iter].GetRenameTo());
- }
- inputs_var = VectorToTuple(lists);
- conversion = "_execute.args_to_mixed_eager_tensors";
- } else {
- // For one list(tensor) argument, we just convert every
- // element of the list to an eager tensor.
- inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
- conversion = "_execute.convert_to_mixed_eager_tensors";
- }
- strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
- " = ", conversion, "(", inputs_var, ", _ctx)\n");
- }
- }
- }
-}
-
-void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
- // Cast remaining args to eager tensors
- for (int i = 0; i < op_def_.input_arg_size(); ++i) {
- const auto& arg(op_def_.input_arg(i));
- if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
- const string& param = param_names_[i].GetRenameTo();
- const string fn = arg.number_attr().empty() ? "" : "n_";
- const string dtype =
- python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
- strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
- "to_tensor(", param, ", ", dtype, ")\n");
- }
-}
-
-void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
- // Compute eager attrs
- if (op_def_.attr_size() > 0) {
- string attr_values;
- for (int i = 0; i < op_def_.attr_size(); ++i) {
- if (i > 0) strings::StrAppend(&attr_values, ", ");
- const auto& attr_name(op_def_.attr(i).name());
- strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
- attr_expressions_[attr_name]);
- }
- strings::StrAppend(&attr_values, ")");
- strings::StrAppend(
- &result_,
- WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
- kRightMargin),
- "\n");
- } else {
- strings::StrAppend(&result_, indentation, "_attrs = None\n");
- }
-}
-
-void GenEagerPythonOp::AddEagerExecute(const string& indentation,
- const string& num_outputs_expr) {
- const string return_prefix =
- strings::StrCat(indentation, "_result = _execute.execute(");
- const string return_args = strings::StrCat(
- "b\"", op_def_.name(), "\", ", num_outputs_expr,
- ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)");
- strings::StrAppend(&result_,
- // Wrap the arguments, and indent to the (.
- WordWrap(return_prefix, return_args, kRightMargin), "\n");
-}
-
-string GetEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs,
- const std::vector<string>& hidden_ops,
- bool require_shapes,
- const string& source_file_name = "") {
- string result;
- // Header
- // TODO(josh11b): Mention the library for which wrappers are being generated.
- strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.
-
-This file is MACHINE GENERATED! Do not edit.
-)");
-
- // Mention the original source file so someone tracing back through
- // generated Python code will know where to look next.
- if (!source_file_name.empty()) {
- strings::StrAppend(&result, "Original C++ source file: ");
- strings::StrAppend(&result, source_file_name);
- strings::StrAppend(&result, "\n");
- }
-
- strings::StrAppend(&result, R"("""
-
-import collections as _collections
-import six as _six
-
-from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
-from tensorflow.python.eager import context as _context
-from tensorflow.python.eager import core as _core
-from tensorflow.python.eager import execute as _execute
-from tensorflow.python.framework import dtypes as _dtypes
-from tensorflow.python.framework import errors as _errors
-from tensorflow.python.framework import tensor_shape as _tensor_shape
-
-from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
-# Needed to trigger the call to _set_call_cpp_shape_fn.
-from tensorflow.python.framework import common_shapes as _common_shapes
-from tensorflow.python.framework import op_def_registry as _op_def_registry
-from tensorflow.python.framework import ops as _ops
-from tensorflow.python.framework import op_def_library as _op_def_library
-from tensorflow.python.util.tf_export import tf_export
-
-)");
-
- // We'll make a copy of ops that filters out descriptions.
- OpList cleaned_ops;
- auto out = cleaned_ops.mutable_op();
- out->Reserve(ops.op_size());
- for (const auto& op_def : ops.op()) {
- const auto* api_def = api_defs.GetApiDef(op_def.name());
-
- if (api_def->visibility() == ApiDef::SKIP) {
- continue;
- }
- // An op is hidden if either its ApiDef visibility is HIDDEN
- // or it is in the hidden_ops list.
- bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
- bool hidden_by_api_def = is_hidden;
- if (!is_hidden) {
- for (const string& hidden : hidden_ops) {
- if (op_def.name() == hidden) {
- is_hidden = true;
- break;
- }
- }
- }
-
- string function_name;
- python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
- &function_name);
- bool is_reserved = python_op_gen_internal::IsPythonReserved(function_name);
-
- // Prefix an op with underscore if the op is listed in hidden_ops or
- // name is reserved or it is of the exceptions in IsOpWithUnderscorePrefix.
- // Do not add underscores to ops set to HIDDEN in ApiDef otherwise.
- // TODO(annarev): don't prefix with underscores even if op is in hidden_ops.
- if (is_hidden) {
- if (!hidden_by_api_def || is_reserved ||
- python_op_gen_internal::IsOpWithUnderscorePrefix(function_name)) {
- function_name = strings::StrCat("_", function_name);
- }
- } else if (is_reserved) {
- // When users create custom python wrappers, they may link in the
- // default op registry by accident, and because they can't
- // enumerate all 'hidden' symbols, this guard is to prevent
- // instantiating a python reserved word in their wrapper.
- continue;
- }
-
- strings::StrAppend(&result,
- GetEagerPythonOp(op_def, *api_def, function_name));
-
- if (!require_shapes) {
- strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
- "\")(None)\n\n");
- }
-
- auto added = out->Add();
- *added = op_def;
- RemoveNonDeprecationDescriptionsFromOpDef(added);
- }
-
- result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
- op_list = _op_def_pb2.OpList()
- op_list.ParseFromString(op_list_proto_bytes)
- _op_def_registry.register_op_list(op_list)
- op_def_lib = _op_def_library.OpDefLibrary()
- op_def_lib.add_op_list(op_list)
- return op_def_lib
-)");
-
- result.append("# ");
- auto ops_text = ProtoDebugString(cleaned_ops);
- str_util::StripTrailingWhitespace(&ops_text);
- result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
- result.append("\n");
- strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
- str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
- return result;
-}
-
-} // namespace
-
-void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs,
- const std::vector<string>& hidden_ops,
- bool require_shapes, const string& source_file_name) {
- printf("%s", GetEagerPythonOps(ops, api_defs, hidden_ops, require_shapes,
- source_file_name)
- .c_str());
-}
-
-string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
- string op_list_str(op_list_buf, op_list_len);
- OpList ops;
- ops.ParseFromString(op_list_str);
-
- ApiDefMap api_def_map(ops);
- return GetEagerPythonOps(ops, api_def_map, {}, false);
-}
-
-} // namespace tensorflow
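The generator deleted above emitted wrappers that try the C fast path first and only fall back to the slower execute path on _FallbackException. A minimal sketch of the shape of the emitted eager code for a hypothetical op Foo (the op name, its single input x, and the fallback helper name are illustrative; the _context, _core, _six and _pywrap_tensorflow aliases come from the generated module preamble quoted above):

def foo(x, name=None):
  _ctx = _context.context()
  try:
    # Fast path straight into the C layer.
    _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
        _ctx._context_handle, _ctx._eager_context.device_name, "Foo", name,
        _ctx._post_execution_callbacks, x)
    return _result
  except _core._FallbackException:
    # Slow path: re-dispatch through the generated fallback function.
    return foo_eager_fallback(x, name=name, ctx=_ctx)
  except _core._NotOkStatusException as e:
    # Unwrap C++ status errors into the matching Python exception type.
    if name is not None:
      message = e.message + " name: " + name
    else:
      message = e.message
    _six.raise_from(_core._status_to_exception(e.code, message), None)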
diff --git a/tensorflow/python/eager/python_eager_op_gen.h b/tensorflow/python/eager/python_eager_op_gen.h
deleted file mode 100644
index d27b00139d..0000000000
--- a/tensorflow/python/eager/python_eager_op_gen.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
-#define TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
-
-#include <string>
-#include <vector>
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/framework/op_gen_lib.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-// hidden_ops should be a list of Op names that should get a leading _
-// in the output. Prints the output to stdout.
-// Optional fourth argument is the name of the original C++ source file
-// where the ops' REGISTER_OP() calls reside.
-void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs,
- const std::vector<string>& hidden_ops,
- bool require_shapes,
- const string& source_file_name = "");
-
-// Get the python wrappers for a list of ops in a OpList.
-// `op_list_buf` should be a pointer to a buffer containing
-// the binary encoded OpList proto, and `op_list_len` should be the
-// length of that buffer.
-string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len);
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index 88982b0c85..bc042eb19e 100644
--- a/tensorflow/python/eager/pywrap_tensor.h
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/numpy.h"
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 2d9a084bc6..d7a9dc5ede 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -21,6 +21,7 @@ py_library(
":export",
":exporter",
":inputs",
+ ":keras",
":linear",
":model_fn",
":parsing_utils",
@@ -445,16 +446,6 @@ py_library(
],
)
-py_test(
- name = "util_test",
- srcs = ["util_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":util",
- "//tensorflow/python:client_testlib",
- ],
-)
-
py_library(
name = "estimator",
srcs = [
@@ -645,7 +636,6 @@ py_library(
":metric_keys",
":model_fn",
":prediction_keys",
- ":util",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
@@ -659,6 +649,7 @@ py_library(
"//tensorflow/python:string_ops",
"//tensorflow/python:summary",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
"//tensorflow/python:weights_broadcast_ops",
"//tensorflow/python/feature_column",
"//tensorflow/python/ops/losses",
@@ -911,3 +902,65 @@ py_test(
"//tensorflow/python:training",
],
)
+
+py_library(
+ name = "keras",
+ srcs = ["keras.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":estimator",
+ ":export_export",
+ ":model_fn",
+ ":run_config",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:partitioned_variables",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:tensor_util",
+ "//tensorflow/python:training",
+ "//tensorflow/python:training_util",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python/keras:backend",
+ "//tensorflow/python/keras:engine",
+ "//tensorflow/python/keras:layers",
+ "//tensorflow/python/ops/losses",
+ "//tensorflow/python/saved_model",
+ "//tensorflow/python/saved_model:signature_constants",
+ ],
+)
+
+py_test(
+ name = "keras_test",
+ size = "large",
+ srcs = ["keras_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["notsan"],
+ deps = [
+ ":keras",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/estimator:run_config",
+ "//tensorflow/python/keras",
+ "//tensorflow/python/keras:backend",
+ "//tensorflow/python/keras:engine",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index e7fbf8eb72..1feac36f35 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -126,7 +126,8 @@ def _dnn_model_fn(features,
activation_fn=nn.relu,
dropout=None,
input_layer_partitioner=None,
- config=None):
+ config=None,
+ tpu_estimator_spec=False):
"""Deep Neural Net model_fn.
Args:
@@ -147,6 +148,8 @@ def _dnn_model_fn(features,
input_layer_partitioner: Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
+ tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or a
+ `model_fn.EstimatorSpec` instance.
Returns:
An `EstimatorSpec` instance.
@@ -154,59 +157,6 @@ def _dnn_model_fn(features,
Raises:
ValueError: If features has the wrong type.
"""
- tpu_estimator_spec = _tpu_dnn_model_fn(
- features=features,
- labels=labels,
- mode=mode,
- head=head,
- hidden_units=hidden_units,
- feature_columns=feature_columns,
- optimizer=optimizer,
- activation_fn=activation_fn,
- dropout=dropout,
- input_layer_partitioner=input_layer_partitioner,
- config=config)
- return tpu_estimator_spec.as_estimator_spec()
-
-
-def _tpu_dnn_model_fn(features,
- labels,
- mode,
- head,
- hidden_units,
- feature_columns,
- optimizer='Adagrad',
- activation_fn=nn.relu,
- dropout=None,
- input_layer_partitioner=None,
- config=None):
- """Deep Neural Net model_fn for TPUEstimator.
-
- Args:
- features: dict of `Tensor`.
- labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
- dtype `int32` or `int64` in the range `[0, n_classes)`.
- mode: Defines whether this is training, evaluation or prediction.
- See `ModeKeys`.
- head: A `head_lib._Head` instance.
- hidden_units: Iterable of integer number of hidden units per layer.
- feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
- optimizer: String, `tf.Optimizer` object, or callable that creates the
- optimizer to use for training. If not specified, will use the Adagrad
- optimizer with a default learning rate of 0.05.
- activation_fn: Activation function applied to each layer.
- dropout: When not `None`, the probability we will drop out a given
- coordinate.
- input_layer_partitioner: Partitioner for input layer. Defaults
- to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
- config: `RunConfig` object to configure the runtime settings.
-
- Returns:
- A `model_fn.TPUEstimatorSpec` instance.
-
- Raises:
- ValueError: If features has the wrong type.
- """
if not isinstance(features, dict):
raise ValueError('features should be a dictionary of `Tensor`s. '
'Given type: {}'.format(type(features)))
@@ -235,12 +185,20 @@ def _tpu_dnn_model_fn(features,
input_layer_partitioner=input_layer_partitioner)
logits = logit_fn(features=features, mode=mode)
- return head._create_tpu_estimator_spec( # pylint: disable=protected-access
- features=features,
- mode=mode,
- labels=labels,
- optimizer=optimizer,
- logits=logits)
+ if tpu_estimator_spec:
+ return head._create_tpu_estimator_spec( # pylint: disable=protected-access
+ features=features,
+ mode=mode,
+ labels=labels,
+ optimizer=optimizer,
+ logits=logits)
+ else:
+ return head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ optimizer=optimizer,
+ logits=logits)
@tf_export('estimator.DNNClassifier')
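With the consolidation above, the same private _dnn_model_fn now serves both the plain Estimator and the TPU path; the new flag decides which spec type is returned. A minimal usage sketch (my_head and my_columns are illustrative placeholders):

# Default: a model_fn.EstimatorSpec, as used by DNNClassifier/DNNRegressor.
spec = _dnn_model_fn(features, labels, mode, head=my_head,
                     hidden_units=[256, 128], feature_columns=my_columns,
                     config=config)

# TPU path: the same call with tpu_estimator_spec=True returns the result of
# head._create_tpu_estimator_spec(...) instead.
tpu_spec = _dnn_model_fn(features, labels, mode, head=my_head,
                         hidden_units=[256, 128], feature_columns=my_columns,
                         config=config, tpu_estimator_spec=True)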
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 232637314d..04fe4d97e4 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -24,7 +24,6 @@ import collections
import six
from tensorflow.python.estimator import model_fn
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export_output
@@ -46,6 +45,7 @@ from tensorflow.python.ops.losses import losses
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util
+from tensorflow.python.util import function_utils
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -461,7 +461,7 @@ def _validate_loss_fn_args(loss_fn):
Raises:
ValueError: If the signature is unexpected.
"""
- loss_fn_args = util.fn_args(loss_fn)
+ loss_fn_args = function_utils.fn_args(loss_fn)
for required_arg in ['labels', 'logits']:
if required_arg not in loss_fn_args:
raise ValueError(
@@ -484,7 +484,7 @@ def _call_loss_fn(loss_fn, labels, logits, features, expected_loss_dim=1):
Returns:
Loss Tensor with shape [D0, D1, ... DN, expected_loss_dim].
"""
- loss_fn_args = util.fn_args(loss_fn)
+ loss_fn_args = function_utils.fn_args(loss_fn)
kwargs = {}
if 'features' in loss_fn_args:
kwargs['features'] = features
@@ -1545,26 +1545,6 @@ def _assert_range(labels, n_classes, message=None):
return array_ops.identity(labels)
-# TODO(b/69000400): Delete this method.
-def _weights(features, weight_column):
- """Fetches weights from features."""
- with ops.name_scope(None, 'weights', values=features.values()):
- if weight_column is None:
- return 1.
- if isinstance(weight_column, six.string_types):
- weight_column = feature_column_lib.numeric_column(
- key=weight_column, shape=(1,))
- if not isinstance(weight_column, feature_column_lib._NumericColumn): # pylint: disable=protected-access
- raise TypeError('Weight column must be either a string or _NumericColumn.'
- ' Given type: {}.'.format(type(weight_column)))
- weights = weight_column._get_dense_tensor( # pylint: disable=protected-access
- feature_column_lib._LazyBuilder(features)) # pylint: disable=protected-access
- if not (weights.dtype.is_floating or weights.dtype.is_integer):
- raise ValueError('Weight column should be castable to float. '
- 'Given dtype: {}'.format(weights.dtype))
- return math_ops.to_float(weights, name='weights')
-
-
def _binary_logistic_or_multi_class_head(
n_classes, weight_column, label_vocabulary, loss_reduction):
"""Creates either binary or multi-class head.
diff --git a/tensorflow/python/estimator/canned/metric_keys.py b/tensorflow/python/estimator/canned/metric_keys.py
index f374d31549..4f7c849ba4 100644
--- a/tensorflow/python/estimator/canned/metric_keys.py
+++ b/tensorflow/python/estimator/canned/metric_keys.py
@@ -42,3 +42,8 @@ class MetricKeys(object):
ACCURACY_AT_THRESHOLD = 'accuracy/positive_threshold_%g'
PRECISION_AT_THRESHOLD = 'precision/positive_threshold_%g'
RECALL_AT_THRESHOLD = 'recall/positive_threshold_%g'
+
+ # The following metric keys require a class id to be applied.
+ PROBABILITY_MEAN_AT_CLASS = 'probability_mean/class%d'
+ AUC_AT_CLASS = 'auc/class%d'
+ AUC_PR_AT_CLASS = 'auc_precision_recall/class%d'
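The three new keys are printf-style templates, so a class id has to be substituted before use. A small sketch (the class id is illustrative):

from tensorflow.python.estimator.canned import metric_keys

class_id = 3
prob_key = metric_keys.MetricKeys.PROBABILITY_MEAN_AT_CLASS % class_id  # 'probability_mean/class3'
auc_key = metric_keys.MetricKeys.AUC_AT_CLASS % class_id                # 'auc/class3'
auc_pr_key = metric_keys.MetricKeys.AUC_PR_AT_CLASS % class_id          # 'auc_precision_recall/class3'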
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 64457eb1ff..a98600b261 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -36,9 +36,9 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.export import export as export_helpers
from tensorflow.python.estimator.export import export_output
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
@@ -62,6 +62,7 @@ from tensorflow.python.training import training_util
from tensorflow.python.training import warm_starting_util
from tensorflow.python.util import compat
from tensorflow.python.util import compat_internal
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -370,6 +371,21 @@ class Estimator(object):
else:
return []
+ def eval_dir(self, name=None):
+ """Shows directory name where evaluation metrics are dumped.
+
+ Args:
+ name: Name of the evaluation if the user needs to run multiple evaluations
+ on different data sets, such as on training data vs. test data. Metrics
+ for different evaluations are saved in separate folders, and appear
+ separately in TensorBoard.
+
+ Returns:
+ A string which is the path of the directory containing evaluation metrics.
+ """
+ return os.path.join(self._model_dir, 'eval' if not name else
+ 'eval_' + name)
+
def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
name=None):
"""Evaluates the model given evaluation data input_fn.
@@ -616,29 +632,28 @@ class Estimator(object):
strip_default_attrs=strip_default_attrs,
mode=model_fn_lib.ModeKeys.PREDICT)
- def _export_all_saved_models(
- self, export_dir_base, input_receiver_fn_map,
+ def _export_saved_model_for_mode(
+ self, export_dir_base, input_receiver_fn,
assets_extra=None,
as_text=False,
checkpoint_path=None,
- strip_default_attrs=False):
+ strip_default_attrs=False,
+ mode=model_fn_lib.ModeKeys.PREDICT):
# pylint: disable=line-too-long
- """Exports requested train/eval/predict graphs as separate SavedModels.
+ """Exports a single train/eval/predict graph as a SavedModel.
- This is a wrapper around export_saved_model_for_mode that accepts
- multiple modes simultaneously and creates directories for each under
- export_dir_base. See `Estimator.export_saved_model_for_mode` for
- further details as to how the export works for each mode.
+ This method is a wrapper for _export_all_saved_models, and wraps a raw
+ input_receiver_fn in a dictionary to pass in to that function.
+ See _export_all_saved_models for full docs.
- See tf.contrib.estimator.export_all_saved_models for the currently
+ See tf.contrib.estimator.export_saved_model_for_mode for the currently
exposed version of this function.
Args:
export_dir_base: A string containing a directory in which to create
timestamped subdirectories containing exported SavedModels.
- input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn
- mappings, where the input_receiver_fn is a function that takes no
- argument and returns the appropriate subclass of `InputReceiver`.
+ input_receiver_fn: a function that takes no argument and
+ returns the appropriate subclass of `InputReceiver`.
assets_extra: A dict specifying how to populate the assets.extra directory
within the exported SavedModel, or `None` if no extra assets are needed.
as_text: whether to write the SavedModel proto in text format.
@@ -647,60 +662,53 @@ class Estimator(object):
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ mode: tf.estimator.ModeKeys value indicating which mode will be exported.
Returns:
- A dict of tf.estimator.ModeKeys value to string path for each exported
- directory.
+ The string path to the exported directory.
Raises:
- ValueError: if any input_receiver_fn is None, no export_outputs
+ ValueError: if input_receiver_fn is None, no export_outputs
are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
- # TODO(b/65561022): Consider allowing multiple input_receiver_fns per mode.
- exported = {}
- for mode, input_receiver_fn in input_receiver_fn_map.items():
- export_mode_dir = os.path.join(
- compat.as_bytes(export_dir_base),
- compat.as_bytes(mode))
- gfile.MakeDirs(export_mode_dir)
-
- exported_path = self._export_saved_model_for_mode(
- export_mode_dir,
- input_receiver_fn,
- assets_extra=assets_extra,
- as_text=as_text,
- checkpoint_path=checkpoint_path,
- strip_default_attrs=strip_default_attrs,
- mode=mode)
+ if not input_receiver_fn:
+ raise ValueError('An input_receiver_fn must be defined.')
- exported[mode] = exported_path
+ input_receiver_fn_map = {mode: input_receiver_fn}
- return exported
+ return self._export_all_saved_models(
+ export_dir_base,
+ input_receiver_fn_map,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ checkpoint_path=checkpoint_path,
+ strip_default_attrs=strip_default_attrs)
- def _export_saved_model_for_mode(
- self, export_dir_base, input_receiver_fn,
+ def _export_all_saved_models(
+ self, export_dir_base, input_receiver_fn_map,
assets_extra=None,
as_text=False,
checkpoint_path=None,
- strip_default_attrs=False,
- mode=model_fn_lib.ModeKeys.PREDICT):
+ strip_default_attrs=False):
# pylint: disable=line-too-long
- """Exports a single train/eval/predict graph as a SavedModel.
-
- For a detailed guide, see
- @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.
+ """Exports a SavedModel containing MetaGraphDefs for each requested mode.
- See tf.contrib.estimator.export_saved_model_for_mode for the currently
+ See tf.contrib.estimator.export_all_saved_models for the currently
exposed version of this function.
- This method takes an input_receiver_fn and mode. For the mode passed in,
+ For each mode passed in via the input_receiver_fn_map,
this method builds a new graph by calling the input_receiver_fn to obtain
feature and label `Tensor`s. Next, this method calls the `Estimator`'s
model_fn in the passed mode to generate the model graph based on
those features and labels, and restores the given checkpoint
(or, lacking that, the most recent checkpoint) into the graph.
- Finally, it creates a timestamped export directory below the
+ Only one of the modes is used for saving variables to the SavedModel
+ (order of preference: TRAIN, EVAL, then PREDICT), such that up to three
+ MetaGraphDefs are saved with a single set of variables in a single
+ SavedModel directory.
+
+ For the variables and MetaGraphDefs, this method creates a timestamped export directory below
export_dir_base, and writes a `SavedModel` into it containing
the `MetaGraphDef` for the given mode and its associated signatures.
@@ -727,8 +735,9 @@ class Estimator(object):
Args:
export_dir_base: A string containing a directory in which to create
timestamped subdirectories containing exported SavedModels.
- input_receiver_fn: a function that takes no argument and
- returns the appropriate subclass of `InputReceiver`.
+ input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn
+ mappings, where the input_receiver_fn is a function that takes no
+ argument and returns the appropriate subclass of `InputReceiver`.
assets_extra: A dict specifying how to populate the assets.extra directory
within the exported SavedModel, or `None` if no extra assets are needed.
as_text: whether to write the SavedModel proto in text format.
@@ -737,20 +746,18 @@ class Estimator(object):
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
- mode: tf.estimator.ModeKeys value indicating with mode will be exported.
Returns:
- The string path to the exported directory.
+ A dict of tf.estimator.ModeKeys value to string path for each exported
+ directory.
Raises:
- ValueError: if input_receiver_fn is None, no export_outputs
+ ValueError: if any input_receiver_fn is None, no export_outputs
are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
+ # TODO(b/65561022): Consider allowing multiple input_receiver_fns per mode.
with context.graph_mode():
- if not input_receiver_fn:
- raise ValueError('An input_receiver_fn must be defined.')
-
if not checkpoint_path:
# Locate the latest checkpoint
checkpoint_path = saver.latest_checkpoint(self._model_dir)
@@ -762,9 +769,34 @@ class Estimator(object):
builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
- self._add_meta_graph_and_variables_for_mode(
- builder, input_receiver_fn, checkpoint_path,
- strip_default_attrs, mode)
+ save_variables = True
+ # Note that the order in which we run here matters, as the first
+ # mode we pass through will be used to save the variables. We run TRAIN
+ # first, as that is also the mode used for checkpoints, and therefore
+ # we are not likely to have vars in PREDICT that are not in the checkpoint
+ # created by TRAIN.
+ if input_receiver_fn_map.get(model_fn_lib.ModeKeys.TRAIN):
+ self._add_meta_graph_for_mode(
+ builder, input_receiver_fn_map, checkpoint_path,
+ strip_default_attrs, save_variables,
+ mode=model_fn_lib.ModeKeys.TRAIN)
+ save_variables = False
+ if input_receiver_fn_map.get(model_fn_lib.ModeKeys.EVAL):
+ self._add_meta_graph_for_mode(
+ builder, input_receiver_fn_map, checkpoint_path,
+ strip_default_attrs, save_variables,
+ mode=model_fn_lib.ModeKeys.EVAL)
+ save_variables = False
+ if input_receiver_fn_map.get(model_fn_lib.ModeKeys.PREDICT):
+ self._add_meta_graph_for_mode(
+ builder, input_receiver_fn_map, checkpoint_path,
+ strip_default_attrs, save_variables,
+ mode=model_fn_lib.ModeKeys.PREDICT)
+ save_variables = False
+
+ if save_variables:
+ raise ValueError('No valid modes for exporting found. Got {}.'.format(
+ input_receiver_fn_map.keys()))
builder.save(as_text)
@@ -782,24 +814,31 @@ class Estimator(object):
gfile.Rename(temp_export_dir, export_dir)
return export_dir
- def _add_meta_graph_and_variables_for_mode(
- self, builder, input_receiver_fn, checkpoint_path, strip_default_attrs,
+ def _add_meta_graph_for_mode(
+ self, builder, input_receiver_fn_map, checkpoint_path,
+ strip_default_attrs, save_variables=True,
mode=model_fn_lib.ModeKeys.PREDICT):
# pylint: disable=line-too-long
"""Loads variables and adds them along with a MetaGraphDef for saving.
Args:
builder: instance of SavedModelBuilder that will be used for saving.
- input_receiver_fn: a function that takes no argument and
- returns the appropriate subclass of `InputReceiver`.
+ input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn
+ mappings, where the input_receiver_fn is a function that takes no
+ argument and returns the appropriate subclass of `InputReceiver`.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ save_variables: bool, whether variables should be saved. If False, just
+ the MetaGraphDef will be saved. Note that save_variables should only be
+ True for the first call to this function, and the SavedModelBuilder will
+ raise an error if that is not the case.
mode: tf.estimator.ModeKeys value indicating which mode will be exported.
"""
# pylint: enable=line-too-long
+ input_receiver_fn = input_receiver_fn_map[mode]
with ops.Graph().as_default() as g:
self._create_and_assert_global_step(g)
random_seed.set_random_seed(self._config.tf_random_seed)
@@ -832,15 +871,24 @@ class Estimator(object):
saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
sharded=True)
- saver_for_restore.restore(session, checkpoint_path)
+
+ try:
+ saver_for_restore.restore(session, checkpoint_path)
+ except errors.NotFoundError as e:
+ msg = ('Could not load all requested variables from the checkpoint. '
+ 'Please make sure your model_fn does not expect variables '
+ 'that were not saved in the checkpoint.\n\n'
+ 'Encountered error with mode `{}` while restoring checkpoint '
+ 'from: `{}`. Full Traceback:\n\n{}').format(
+ mode, checkpoint_path, e)
+ raise ValueError(msg)
# We add the train op explicitly for now, so that we don't have to
# change the Builder public interface. Note that this is a no-op
# for prediction, where train_op is None.
builder._add_train_op(estimator_spec.train_op) # pylint: disable=protected-access
- builder.add_meta_graph_and_variables(
- session,
+ meta_graph_kwargs = dict(
tags=export_tags,
signature_def_map=signature_def_map,
assets_collection=ops.get_collection(
@@ -848,6 +896,12 @@ class Estimator(object):
strip_default_attrs=strip_default_attrs,
legacy_init_op=local_init_op)
+ if save_variables:
+ builder.add_meta_graph_and_variables(
+ session, **meta_graph_kwargs)
+ else:
+ builder.add_meta_graph(**meta_graph_kwargs)
+
def _get_export_outputs_for_spec(self, estimator_spec):
"""Given an EstimatorSpec, determine what our export outputs should be.
@@ -998,7 +1052,7 @@ class Estimator(object):
Raises:
ValueError: if input_fn takes invalid arguments.
"""
- input_fn_args = util.fn_args(input_fn)
+ input_fn_args = function_utils.fn_args(input_fn)
kwargs = {}
if 'mode' in input_fn_args:
kwargs['mode'] = mode
@@ -1024,7 +1078,7 @@ class Estimator(object):
Raises:
ValueError: if model_fn returns invalid objects.
"""
- model_fn_args = util.fn_args(self._model_fn)
+ model_fn_args = function_utils.fn_args(self._model_fn)
kwargs = {}
if 'labels' in model_fn_args:
kwargs['labels'] = labels
@@ -1286,10 +1340,6 @@ class Estimator(object):
'initialization to evaluate.'.format(self._model_dir))
checkpoint_path = latest_path
- # Setup output directory.
- eval_dir = os.path.join(self._model_dir, 'eval' if not name else
- 'eval_' + name)
-
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
global_step_tensor = self._create_and_assert_global_step(g)
@@ -1333,7 +1383,7 @@ class Estimator(object):
config=self._session_config)
_write_dict_to_summary(
- output_dir=eval_dir,
+ output_dir=self.eval_dir(name),
dictionary=eval_results,
current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP])
@@ -1433,7 +1483,7 @@ def _get_replica_device_setter(config):
def _verify_model_fn_args(model_fn, params):
"""Verifies model fn arguments."""
- args = set(util.fn_args(model_fn))
+ args = set(function_utils.fn_args(model_fn))
if 'features' not in args:
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
if params is not None and 'params' not in args:
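Taken together, the changes above make one export directory hold up to three MetaGraphDefs with a single set of variables. A minimal sketch of driving the private multi-mode export (the receiver functions are illustrative; variables are written once, from the first mode present in the order TRAIN, EVAL, PREDICT):

input_receiver_fn_map = {
    model_fn_lib.ModeKeys.TRAIN: train_input_receiver_fn,
    model_fn_lib.ModeKeys.EVAL: eval_input_receiver_fn,
    model_fn_lib.ModeKeys.PREDICT: serving_input_receiver_fn,
}
# Returns a single timestamped export directory under export_dir_base.
export_dir = est._export_all_saved_models(export_dir_base, input_receiver_fn_map)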
diff --git a/tensorflow/python/estimator/estimator_lib.py b/tensorflow/python/estimator/estimator_lib.py
index 3815f42470..f188f2d4e6 100644
--- a/tensorflow/python/estimator/estimator_lib.py
+++ b/tensorflow/python/estimator/estimator_lib.py
@@ -39,6 +39,7 @@ from tensorflow.python.estimator.exporter import Exporter
from tensorflow.python.estimator.exporter import FinalExporter
from tensorflow.python.estimator.exporter import LatestExporter
from tensorflow.python.estimator.inputs import inputs
+from tensorflow.python.estimator.keras import model_to_estimator
from tensorflow.python.estimator.model_fn import EstimatorSpec
from tensorflow.python.estimator.model_fn import ModeKeys
from tensorflow.python.estimator.run_config import RunConfig
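The newly wired-in model_to_estimator converts a compiled Keras model into an Estimator. A minimal sketch, using the import path added above (the model architecture and the keras_model keyword are assumptions for illustration):

import tensorflow as tf
from tensorflow.python.estimator.keras import model_to_estimator

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

est = model_to_estimator(keras_model=model)  # a tf.estimator.Estimator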
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 02088e5134..1b70189948 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -33,7 +33,6 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.export import export_output
from tensorflow.python.estimator.inputs import numpy_io
@@ -72,6 +71,7 @@ from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
_TMP_DIR = '/tmp'
_ANOTHER_TMP_DIR = '/another_tmp'
@@ -332,7 +332,7 @@ class EstimatorConstructorTest(test.TestCase):
_, _, _, _, _ = features, labels, mode, config, params
est = estimator.Estimator(model_fn=model_fn)
- model_fn_args = util.fn_args(est.model_fn)
+ model_fn_args = function_utils.fn_args(est.model_fn)
self.assertEqual(
set(['features', 'labels', 'mode', 'config']), set(model_fn_args))
@@ -342,7 +342,7 @@ class EstimatorConstructorTest(test.TestCase):
_, _ = features, labels
est = estimator.Estimator(model_fn=model_fn)
- model_fn_args = util.fn_args(est.model_fn)
+ model_fn_args = function_utils.fn_args(est.model_fn)
self.assertEqual(
set(['features', 'labels', 'mode', 'config']), set(model_fn_args))
@@ -1061,6 +1061,15 @@ class EstimatorDatasetIntegrationTest(test.TestCase):
class EstimatorEvaluateTest(test.TestCase):
+ def test_eval_dir(self):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ model_dir='some_path')
+ expected_eval_dir = os.path.join('some_path', 'eval')
+ self.assertEqual(expected_eval_dir, est.eval_dir())
+ expected_eval_dir_name = os.path.join('some_path', 'eval_a_name')
+ self.assertEqual(expected_eval_dir_name, est.eval_dir('a_name'))
+
def test_input_fn_args(self):
expected_mode = model_fn_lib.ModeKeys.EVAL
expected_params = {'batch_size': 10}
@@ -1385,7 +1394,7 @@ class EstimatorEvaluateTest(test.TestCase):
# Get last evaluation Event written.
for key in ['foo/0', 'foo/1', 'foo/2']:
self.assertTrue(
- check_eventfile_for_keyword(key, os.path.join(est.model_dir, 'eval')),
+ check_eventfile_for_keyword(key, est.eval_dir()),
'{} should be part of reported summaries.'.format(key))
@@ -2013,12 +2022,9 @@ class EstimatorExportTest(test.TestCase):
input_receiver_fn_map = {
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 1)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
@@ -2035,12 +2041,9 @@ class EstimatorExportTest(test.TestCase):
input_receiver_fn_map = {
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 1)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -2058,12 +2061,9 @@ class EstimatorExportTest(test.TestCase):
input_receiver_fn_map = {
model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 1)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.EVAL], export_dir)
@@ -2082,12 +2082,9 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- self.assertEqual(len(export_dirs), 2)
- # Restore, to validate that the export was well-formed.
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -2096,7 +2093,7 @@ class EstimatorExportTest(test.TestCase):
self.assertFalse('eval_multiplied' in graph_ops)
self.assertTrue('feature_x' in graph_ops)
self.assertTrue('weight' in graph_ops)
- export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL]
+
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.EVAL], export_dir)
@@ -2117,12 +2114,11 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
# Restore, to validate that the export was well-formed.
- for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items():
- export_dir = export_dirs[mode]
+ for tag_set in model_fn_lib.EXPORT_TAG_MAP.values():
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, tag_set, export_dir)
@@ -2139,10 +2135,9 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -2150,7 +2145,6 @@ class EstimatorExportTest(test.TestCase):
self.assertTrue('later_var' in graph_ops)
self.assertTrue('weight' in graph_ops)
- export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
@@ -2166,10 +2160,9 @@ class EstimatorExportTest(test.TestCase):
model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
}
- export_dirs, tmpdir = self._test_export_all_saved_models(
+ export_dir, tmpdir = self._test_export_all_saved_models(
input_receiver_fn_map)
- export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.TRAINING], export_dir)
@@ -2179,7 +2172,6 @@ class EstimatorExportTest(test.TestCase):
collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertEqual(3, collection_vars[-1].eval())
- export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT]
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
@@ -2207,16 +2199,15 @@ class EstimatorExportTest(test.TestCase):
# Perform the export.
export_dir_base = os.path.join(
compat.as_bytes(tmpdir), compat.as_bytes('export'))
- export_dirs = est._export_all_saved_models(
+ export_dir = est._export_all_saved_models(
export_dir_base, input_receiver_fn_map)
# Check that all the files are in the right places.
self.assertTrue(gfile.Exists(export_dir_base))
- for _, export_dir in export_dirs.items():
- self._validate_exported_files(export_dir)
+ self._validate_exported_files(export_dir)
- return export_dirs, tmpdir
+ return export_dir, tmpdir
def _validate_exported_files(self, export_dir):
self.assertTrue(gfile.Exists(export_dir))
@@ -2233,6 +2224,42 @@ class EstimatorExportTest(test.TestCase):
compat.as_bytes(export_dir),
compat.as_bytes('variables/variables.data-00000-of-00001'))))
+ def test_export_all_saved_models_var_not_found(self):
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+
+ def _model_fn_with_predict_only_vars(features, labels, mode):
+ _, _ = features, labels
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ variables.Variable(1., name='only_in_predict')
+ else:
+ variables.Variable(1., name='otherwise')
+
+ prediction = constant_op.constant(1.)
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=prediction,
+ loss=constant_op.constant(1.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ export_outputs={
+ 'test': export_output.PredictOutput({'prediction': prediction})
+ })
+
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn_with_predict_only_vars)
+ est.train(input_fn=_x_y_input_fn, steps=1)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+
+ err_regex = r'Could not load all requested variables[\w\W]*infer'
+ with self.assertRaisesRegexp(ValueError, err_regex):
+ est._export_all_saved_models(export_dir_base, input_receiver_fn_map)
+
def test_export_savedmodel_with_saveables_proto_roundtrip(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(
@@ -2464,6 +2491,43 @@ class EstimatorExportTest(test.TestCase):
self.assertTrue(self.mock_saver.restore.called)
+ def test_scaffold_is_used_for_saver_multiple_modes(self):
+ tmpdir = tempfile.mkdtemp()
+
+ def _model_fn_scaffold(features, labels, mode):
+ _, _ = features, labels
+ variables.Variable(1., name='weight')
+ real_saver = saver.Saver()
+ self.mock_saver = test.mock.Mock(
+ wraps=real_saver, saver_def=real_saver.saver_def)
+ scores = constant_op.constant([3.])
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ scaffold = training.Scaffold(saver=self.mock_saver)
+ else:
+ scaffold = training.Scaffold()
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ predictions=constant_op.constant([[1.]]),
+ loss=constant_op.constant(0.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ scaffold=scaffold,
+ export_outputs={'test': export_output.ClassificationOutput(scores)})
+
+ est = estimator.Estimator(model_fn=_model_fn_scaffold)
+ est.train(dummy_input_fn, steps=1)
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ est._export_all_saved_models(export_dir_base, input_receiver_fn_map)
+
+ self.assertTrue(self.mock_saver.restore.called)
+
def test_scaffold_is_used_for_local_init(self):
tmpdir = tempfile.mkdtemp()
@@ -2509,6 +2573,61 @@ class EstimatorExportTest(test.TestCase):
my_int_value = sess.run(my_int)
self.assertEqual(12345, my_int_value)
+ def test_scaffold_is_used_for_local_init_multiple_modes(self):
+ tmpdir = tempfile.mkdtemp()
+
+ def _model_fn_scaffold(features, labels, mode):
+ _, _ = features, labels
+ my_int = variables.Variable(1, name='my_int',
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ scores = constant_op.constant([3.])
+ with ops.control_dependencies([
+ variables.local_variables_initializer(),
+ lookup_ops.tables_initializer()
+ ]):
+ assign_op = state_ops.assign(my_int, 12345)
+
+ custom_local_init_op = None
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ # local_init_op must be an Operation, not a Tensor.
+ custom_local_init_op = control_flow_ops.group(assign_op)
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ predictions=constant_op.constant([[1.]]),
+ loss=constant_op.constant(0.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ scaffold=training.Scaffold(local_init_op=custom_local_init_op),
+ export_outputs={'test': export_output.ClassificationOutput(scores)})
+
+ est = estimator.Estimator(model_fn=_model_fn_scaffold)
+ est.train(dummy_input_fn, steps=1)
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn()
+ }
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = est._export_all_saved_models(
+ export_dir_base, input_receiver_fn_map)
+
+ # Restore, to validate that the custom local_init_op runs.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ my_int = graph.get_tensor_by_name('my_int:0')
+ my_int_value = sess.run(my_int)
+ self.assertEqual(12345, my_int_value)
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ my_int = graph.get_tensor_by_name('my_int:0')
+ my_int_value = sess.run(my_int)
+ self.assertEqual(1, my_int_value)
+
def test_features_labels_mode(self):
given_features = {'test-features': constant_op.constant([[1], [1]])}
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index 9aafb56679..48ae8cd497 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -14,7 +14,6 @@
# ==============================================================================
"""Configuration and utilities for receiving inputs at serving time."""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -37,7 +36,6 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
-
_SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
_SINGLE_LABEL_DEFAULT_NAME = 'label'
@@ -69,11 +67,11 @@ def _wrap_and_check_receiver_tensors(receiver_tensors):
def _check_tensor(tensor, name, error_label='feature'):
"""Check that passed `tensor` is a Tensor or SparseTensor."""
- if not (isinstance(tensor, ops.Tensor)
- or isinstance(tensor, sparse_tensor.SparseTensor)):
+ if not (isinstance(tensor, ops.Tensor) or
+ isinstance(tensor, sparse_tensor.SparseTensor)):
fmt_name = ' {}'.format(name) if name else ''
- value_error = ValueError(
- '{}{} must be a Tensor or SparseTensor.'.format(error_label, fmt_name))
+ value_error = ValueError('{}{} must be a Tensor or SparseTensor.'.format(
+ error_label, fmt_name))
# NOTE(ericmc): This if-else block is a specific carve-out for
# LabeledTensor, which has a `.tensor` attribute and which is
# convertible to tf.Tensor via ops.convert_to_tensor.
@@ -92,19 +90,23 @@ def _check_tensor(tensor, name, error_label='feature'):
def _check_tensor_key(name, error_label='feature'):
if not isinstance(name, six.string_types):
- raise ValueError(
- '{} keys must be strings: {}.'.format(error_label, name))
+ raise ValueError('{} keys must be strings: {}.'.format(error_label, name))
@tf_export('estimator.export.ServingInputReceiver')
-class ServingInputReceiver(collections.namedtuple(
- 'ServingInputReceiver',
- ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])):
+class ServingInputReceiver(
+ collections.namedtuple(
+ 'ServingInputReceiver',
+ ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])):
"""A return type for a serving_input_receiver_fn.
The expected return values are:
features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
- `SparseTensor`, specifying the features to be passed to the model.
+ `SparseTensor`, specifying the features to be passed to the model. Note:
+ if `features` passed is not a dict, it will be wrapped in a dict with a
+ single entry, using 'feature' as the key. Consequently, the model must
+ accept a feature dict of the form {'feature': tensor}. You may use
+ `TensorServingInputReceiver` if you want the tensor to be passed as is.
receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`
or `SparseTensor`, specifying input nodes where this receiver expects to
be fed by default. Typically, this is a single placeholder expecting
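As the expanded docstring notes, a bare Tensor or SparseTensor passed as features is wrapped into a single-entry dict keyed by 'feature', so the model must accept {'feature': tensor}; TensorServingInputReceiver passes the tensor through unchanged. A minimal sketch using the module's own array_ops/dtypes aliases (shapes are illustrative):

serialized = array_ops.placeholder(dtypes.string, shape=[None],
                                   name='input_example_tensor')
receiver = ServingInputReceiver(
    features=array_ops.placeholder(dtypes.float32, shape=[None, 3]),
    receiver_tensors={'examples': serialized})
# receiver.features is now {'feature': <the float placeholder>}.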
@@ -119,7 +121,9 @@ class ServingInputReceiver(collections.namedtuple(
Defaults to None.
"""
- def __new__(cls, features, receiver_tensors,
+ def __new__(cls,
+ features,
+ receiver_tensors,
receiver_tensors_alternatives=None):
if features is None:
raise ValueError('features must be defined.')
@@ -139,8 +143,9 @@ class ServingInputReceiver(collections.namedtuple(
for alternative_name, receiver_tensors_alt in (
six.iteritems(receiver_tensors_alternatives)):
if not isinstance(receiver_tensors_alt, dict):
- receiver_tensors_alt = {_SINGLE_RECEIVER_DEFAULT_NAME:
- receiver_tensors_alt}
+ receiver_tensors_alt = {
+ _SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt
+ }
# Updating dict during iteration is OK in this case.
receiver_tensors_alternatives[alternative_name] = (
receiver_tensors_alt)
@@ -157,9 +162,10 @@ class ServingInputReceiver(collections.namedtuple(
@tf_export('estimator.export.TensorServingInputReceiver')
-class TensorServingInputReceiver(collections.namedtuple(
- 'TensorServingInputReceiver',
- ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])):
+class TensorServingInputReceiver(
+ collections.namedtuple(
+ 'TensorServingInputReceiver',
+ ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])):
"""A return type for a serving_input_receiver_fn.
This is for use with models that expect a single `Tensor` or `SparseTensor`
@@ -194,7 +200,9 @@ class TensorServingInputReceiver(collections.namedtuple(
Defaults to None.
"""
- def __new__(cls, features, receiver_tensors,
+ def __new__(cls,
+ features,
+ receiver_tensors,
receiver_tensors_alternatives=None):
if features is None:
raise ValueError('features must be defined.')
@@ -212,9 +220,9 @@ class TensorServingInputReceiver(collections.namedtuple(
receiver_tensors_alternatives=receiver.receiver_tensors_alternatives)
-class SupervisedInputReceiver(collections.namedtuple(
- 'SupervisedInputReceiver',
- ['features', 'labels', 'receiver_tensors'])):
+class SupervisedInputReceiver(
+ collections.namedtuple('SupervisedInputReceiver',
+ ['features', 'labels', 'receiver_tensors'])):
"""A return type for a training_input_receiver_fn or eval_input_receiver_fn.
This differs from a ServingInputReceiver in that (1) this receiver expects
@@ -272,11 +280,13 @@ def build_parsing_serving_input_receiver_fn(feature_spec,
Returns:
A serving_input_receiver_fn suitable for use in serving.
"""
+
def serving_input_receiver_fn():
"""An input_fn that expects a serialized tf.Example."""
- serialized_tf_example = array_ops.placeholder(dtype=dtypes.string,
- shape=[default_batch_size],
- name='input_example_tensor')
+ serialized_tf_example = array_ops.placeholder(
+ dtype=dtypes.string,
+ shape=[default_batch_size],
+ name='input_example_tensor')
receiver_tensors = {'examples': serialized_tf_example}
features = parsing_ops.parse_example(serialized_tf_example, feature_spec)
return ServingInputReceiver(features, receiver_tensors)
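A sketch of how the parsing receiver above might be used; the feature spec and feature name are hypothetical, not from the patch:

    import tensorflow as tf

    feature_spec = {'age': tf.FixedLenFeature([1], dtype=tf.float32)}
    serving_input_receiver_fn = (
        tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
    receiver = serving_input_receiver_fn()
    # Serialized tf.Example protos are fed under the 'examples' key, and the
    # parsed features are exposed to the model.
    print(sorted(receiver.receiver_tensors.keys()))  # ['examples']
    print(sorted(receiver.features.keys()))          # ['age']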
@@ -295,10 +305,12 @@ def _placeholder_from_tensor(t, default_batch_size=None):
return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name)
-def _placeholders_from_receiver_tensors_dict(
- input_vals, default_batch_size=None):
- return {name: _placeholder_from_tensor(t, default_batch_size)
- for name, t in input_vals.items()}
+def _placeholders_from_receiver_tensors_dict(input_vals,
+ default_batch_size=None):
+ return {
+ name: _placeholder_from_tensor(t, default_batch_size)
+ for name, t in input_vals.items()
+ }
@tf_export('estimator.export.build_raw_serving_input_receiver_fn')
@@ -316,6 +328,7 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
Returns:
A serving_input_receiver_fn.
"""
+
def serving_input_receiver_fn():
"""A serving_input_receiver_fn that expects features to be fed directly."""
receiver_tensors = _placeholders_from_receiver_tensors_dict(
@@ -329,8 +342,9 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
return serving_input_receiver_fn
-def build_raw_supervised_input_receiver_fn(
- features, labels, default_batch_size=None):
+def build_raw_supervised_input_receiver_fn(features,
+ labels,
+ default_batch_size=None):
"""Build a supervised_input_receiver_fn for raw features and labels.
This function wraps tensor placeholders in a supervised_receiver_fn
@@ -443,11 +457,12 @@ def build_all_signature_defs(receiver_tensors,
for receiver_name, receiver_tensors_alt in (
six.iteritems(receiver_tensors_alternatives)):
if not isinstance(receiver_tensors_alt, dict):
- receiver_tensors_alt = {_SINGLE_RECEIVER_DEFAULT_NAME:
- receiver_tensors_alt}
+ receiver_tensors_alt = {
+ _SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt
+ }
for output_key, export_output in export_outputs.items():
- signature_name = '{}:{}'.format(receiver_name or 'None',
- output_key or 'None')
+ signature_name = '{}:{}'.format(receiver_name or 'None', output_key or
+ 'None')
try:
signature = export_output.as_signature_def(receiver_tensors_alt)
signature_def_map[signature_name] = signature
@@ -464,8 +479,11 @@ def build_all_signature_defs(receiver_tensors,
# signatures produced for serving. We skip this check for training and eval
# signatures, which are not intended for serving.
if serving_only:
- signature_def_map = {k: v for k, v in signature_def_map.items()
- if signature_def_utils.is_valid_signature(v)}
+ signature_def_map = {
+ k: v
+ for k, v in signature_def_map.items()
+ if signature_def_utils.is_valid_signature(v)
+ }
return signature_def_map
@@ -506,8 +524,8 @@ def _log_signature_report(signature_def_map, excluded_signatures):
if not signature_def_map:
logging.warn('Export includes no signatures!')
- elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
- not in signature_def_map):
+ elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in
+ signature_def_map):
logging.warn('Export includes no default signature!')
@@ -547,6 +565,5 @@ def get_temp_export_dir(timestamped_export_dir):
"""
(dirname, basename) = os.path.split(timestamped_export_dir)
temp_export_dir = os.path.join(
- compat.as_bytes(dirname),
- compat.as_bytes('temp-{}'.format(basename)))
+ compat.as_bytes(dirname), compat.as_bytes('temp-{}'.format(basename)))
return temp_export_dir
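For illustration, the helper above only prefixes the basename with 'temp-'; the timestamped path here is hypothetical:

    import os
    from tensorflow.python.util import compat

    dirname, basename = os.path.split('/tmp/export/1526486400')
    print(os.path.join(compat.as_bytes(dirname),
                       compat.as_bytes('temp-{}'.format(basename))))
    # b'/tmp/export/temp-1526486400'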
diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/estimator/keras.py
index 5c79c964c8..5c79c964c8 100644
--- a/tensorflow/python/keras/_impl/keras/estimator.py
+++ b/tensorflow/python/estimator/keras.py
diff --git a/tensorflow/python/keras/_impl/keras/estimator_test.py b/tensorflow/python/estimator/keras_test.py
index 80fa87d041..a89f7f7db3 100644
--- a/tensorflow/python/keras/_impl/keras/estimator_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -25,6 +25,7 @@ import tempfile
import numpy as np
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.estimator import keras as keras_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import ops
@@ -192,7 +193,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
metrics=['mse', keras.metrics.categorical_accuracy])
with self.test_session():
- est_keras = keras.estimator.model_to_estimator(
+ est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
before_eval_results = est_keras.evaluate(
input_fn=eval_input_fn, steps=1)
@@ -214,7 +215,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
metrics=['mse', keras.metrics.categorical_accuracy])
with self.test_session():
- est_keras = keras.estimator.model_to_estimator(
+ est_keras = keras_lib.model_to_estimator(
keras_model=keras_model,
# Also use dict config argument to get test coverage for that line.
config={
@@ -240,7 +241,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
metrics=['mse', keras.metrics.categorical_accuracy])
with self.test_session():
- est_keras = keras.estimator.model_to_estimator(
+ est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
before_eval_results = est_keras.evaluate(
@@ -264,7 +265,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
np.random.random((10, _NUM_CLASS)))
original_preds = keras_model.predict(np.ones((10,) + _INPUT_SIZE))
- est_keras = keras.estimator.model_to_estimator(
+ est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
before_eval_results = est_keras.evaluate(
@@ -300,7 +301,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_eval = keras_model.evaluate(x_test, y_test, batch_size=32)
with self.test_session():
- keras_est = keras.estimator.model_to_estimator(
+ keras_est = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_eval = keras_est.evaluate(input_fn=eval_input_fn)
@@ -336,7 +337,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)]
with self.test_session():
- keras_est = keras.estimator.model_to_estimator(
+ keras_est = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_pred = [
np.argmax(y[keras_model.output_names[0]])
@@ -383,7 +384,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
with self.test_session():
model = multi_inputs_multi_outputs_model()
- est_keras = keras.estimator.model_to_estimator(
+ est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
before_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
@@ -409,7 +410,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras.models.save_model(keras_model, fname)
with self.test_session():
- keras_est = keras.estimator.model_to_estimator(
+ keras_est = keras_lib.model_to_estimator(
keras_model_path=fname, config=self._config)
est_pred = [
np.argmax(y[keras_model.output_names[0]])
@@ -419,24 +420,24 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
def test_keras_model_init_error(self):
with self.assertRaisesRegexp(ValueError, 'Either'):
- keras.estimator.model_to_estimator()
+ keras_lib.model_to_estimator()
with self.test_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'not both'):
- keras.estimator.model_to_estimator(
+ keras_lib.model_to_estimator(
keras_model=keras_model,
keras_model_path=tempfile.mkdtemp(dir=self._base_dir))
with self.test_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'compiled'):
- keras.estimator.model_to_estimator(keras_model=keras_model)
+ keras_lib.model_to_estimator(keras_model=keras_model)
with self.test_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'not a local path'):
- keras.estimator.model_to_estimator(
+ keras_lib.model_to_estimator(
keras_model_path='gs://bucket/object')
def test_invalid_ionames_error(self):
@@ -460,7 +461,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
model.compile(
loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
with self.test_session():
- est_keras = keras.estimator.model_to_estimator(
+ est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
with self.test_session():
@@ -479,12 +480,12 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
}
with self.assertRaisesRegexp(ValueError, 'relu6'):
with self.test_session():
- keras.estimator.model_to_estimator(
+ keras_lib.model_to_estimator(
keras_model=keras_mobile,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
with self.test_session():
- keras.estimator.model_to_estimator(
+ keras_lib.model_to_estimator(
keras_model=keras_mobile,
model_dir=tempfile.mkdtemp(dir=self._base_dir),
custom_objects=custom_objects)
@@ -509,7 +510,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
})
with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
with self.test_session():
- keras.estimator.model_to_estimator(
+ keras_lib.model_to_estimator(
keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
@@ -524,7 +525,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
self._config._session_config = sess_config
- keras.estimator.model_to_estimator(
+ keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
self.assertEqual(
keras.backend.get_session()
@@ -548,7 +549,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
loss='categorical_crossentropy',
optimizer=SGD(lr=0.0001, momentum=0.9),
metrics=['mse', keras.metrics.categorical_accuracy])
- keras.estimator.model_to_estimator(
+ keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 8162b249f1..c7707be839 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -27,8 +27,8 @@ import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
-from tensorflow.python.estimator import util
from tensorflow.python.util import compat_internal
+from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export
@@ -283,7 +283,7 @@ def _validate_properties(run_config):
message='tf_random_seed must be integer.')
_validate('device_fn', lambda device_fn: six.callable(device_fn) and
- set(util.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
+ set(function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
message='device_fn must be callable with exactly'
' one argument "op".')
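A sketch of a device_fn that passes the check above, i.e. a callable whose only argument is named 'op'; the placement rule itself is illustrative, not from the patch:

    from tensorflow.python.util import function_utils

    def device_fn(op):
      # Pin variables to the CPU and leave everything else to the default placer.
      return '/cpu:0' if op.type in ('Variable', 'VariableV2') else ''

    assert set(function_utils.fn_args(device_fn)) == {'op'}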
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index bb4bdd3fdf..e4e1d37f74 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -13,55 +13,21 @@
# limitations under the License.
# ==============================================================================
-"""Utility to retrieve function args."""
+"""Utilities for Estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
import os
import time
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
-from tensorflow.python.util import tf_decorator
-from tensorflow.python.util import tf_inspect
-
-
-def _is_bounded_method(fn):
- _, fn = tf_decorator.unwrap(fn)
- return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
-
-
-def _is_callable_object(obj):
- return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__)
-
-
-def fn_args(fn):
- """Get argument names for function-like object.
-
- Args:
- fn: Function, or function-like object (e.g., result of `functools.partial`).
-
- Returns:
- `tuple` of string argument names.
-
- Raises:
- ValueError: if partial function has positionally bound arguments
- """
- if isinstance(fn, functools.partial):
- args = fn_args(fn.func)
- args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
- else:
- if _is_callable_object(fn):
- fn = fn.__call__
- args = tf_inspect.getfullargspec(fn).args
- if _is_bounded_method(fn):
- args.remove('self')
- return tuple(args)
+from tensorflow.python.util import function_utils
+fn_args = function_utils.fn_args
# When we create a timestamped directory, there is a small chance that the
# directory already exists because another process is also creating these
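The re-exported fn_args behaves as the removed docstring described: it returns argument names and handles functools.partial. A short sketch with a hypothetical model_fn:

    import functools
    from tensorflow.python.util import function_utils

    def model_fn(features, labels, mode, params):
      return None

    print(function_utils.fn_args(model_fn))
    # ('features', 'labels', 'mode', 'params')
    print(function_utils.fn_args(functools.partial(model_fn, params={})))
    # ('features', 'labels', 'mode')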
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 40386ae7aa..52e58d7ab5 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -1068,6 +1068,7 @@ def numeric_column(key,
raise TypeError(
'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))
+ _assert_key_is_string(key)
return _NumericColumn(
key,
shape=shape,
@@ -1166,6 +1167,13 @@ def _assert_string_or_int(dtype, prefix):
'{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
+def _assert_key_is_string(key):
+ if not isinstance(key, six.string_types):
+ raise ValueError(
+ 'key must be a string. Got: type {}. Given key: {}.'.format(
+ type(key), key))
+
+
@tf_export('feature_column.categorical_column_with_hash_bucket')
def categorical_column_with_hash_bucket(key,
hash_bucket_size,
@@ -1218,6 +1226,7 @@ def categorical_column_with_hash_bucket(key,
'hash_bucket_size: {}, key: {}'.format(
hash_bucket_size, key))
+ _assert_key_is_string(key)
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
return _HashedCategoricalColumn(key, hash_bucket_size, dtype)
@@ -1334,6 +1343,7 @@ def categorical_column_with_vocabulary_file(key,
raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
num_oov_buckets, key))
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+ _assert_key_is_string(key)
return _VocabularyFileCategoricalColumn(
key=key,
vocabulary_file=vocabulary_file,
@@ -1448,6 +1458,7 @@ def categorical_column_with_vocabulary_list(
'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
dtype, vocabulary_dtype, key))
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+ _assert_key_is_string(key)
return _VocabularyListCategoricalColumn(
key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype,
@@ -1518,6 +1529,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
raise ValueError(
'default_value {} not in range [0, {}), column_name {}'.format(
default_value, num_buckets, key))
+ _assert_key_is_string(key)
return _IdentityCategoricalColumn(
key=key, num_buckets=num_buckets, default_value=default_value)
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index b06540489f..03c47eea31 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -182,6 +182,10 @@ class NumericColumnTest(test.TestCase):
self.assertEqual(dtypes.float32, a.dtype)
self.assertIsNone(a.normalizer_fn)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.numeric_column(key=('aaa',))
+
def test_shape_saved_as_tuple(self):
a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]])
self.assertEqual((1, 2), a.shape)
@@ -645,6 +649,10 @@ class HashedCategoricalColumnTest(test.TestCase):
self.assertEqual(10, a.hash_bucket_size)
self.assertEqual(dtypes.string, a.dtype)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_hash_bucket(('key',), 10)
+
def test_bucket_size_should_be_given(self):
with self.assertRaisesRegexp(ValueError, 'hash_bucket_size must be set.'):
fc.categorical_column_with_hash_bucket('aaa', None)
@@ -3327,6 +3335,11 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
'aaa': parsing_ops.VarLenFeature(dtypes.string)
}, column._parse_example_spec)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_vocabulary_file(
+ key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3)
+
def test_all_constructor_args(self):
column = fc.categorical_column_with_vocabulary_file(
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
@@ -3752,6 +3765,11 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
'aaa': parsing_ops.VarLenFeature(dtypes.string)
}, column._parse_example_spec)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_vocabulary_list(
+ key=('aaa',), vocabulary_list=('omar', 'stringer', 'marlo'))
+
def test_defaults_int(self):
column = fc.categorical_column_with_vocabulary_list(
key='aaa', vocabulary_list=(12, 24, 36))
@@ -4143,6 +4161,10 @@ class IdentityCategoricalColumnTest(test.TestCase):
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
}, column._parse_example_spec)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_identity(key=('aaa',), num_buckets=3)
+
def test_deep_copy(self):
original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
for column in (original, copy.deepcopy(original)):
diff --git a/tensorflow/python/framework/c_api_util.py b/tensorflow/python/framework/c_api_util.py
index 7bbe3183df..aff289f7be 100644
--- a/tensorflow/python/framework/c_api_util.py
+++ b/tensorflow/python/framework/c_api_util.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.core.framework import api_def_pb2
+from tensorflow.core.framework import op_def_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
@@ -89,6 +91,50 @@ class ScopedTFFunction(object):
c_api.TF_DeleteFunction(self.func)
+class ApiDefMap(object):
+ """Wrapper around Tf_ApiDefMap that handles querying and deletion.
+
+ The OpDef protos are also stored in this class so that they could
+ be queried by op name.
+ """
+
+ def __init__(self):
+ op_def_proto = op_def_pb2.OpList()
+ buf = c_api.TF_GetAllOpList()
+ try:
+ op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
+ self._api_def_map = c_api.TF_NewApiDefMap(buf)
+ finally:
+ c_api.TF_DeleteBuffer(buf)
+
+ self._op_per_name = {}
+ for op in op_def_proto.op:
+ self._op_per_name[op.name] = op
+
+ def __del__(self):
+ # Note: when the global context is being destructed (i.e. when the process
+ # is terminating), other modules may already have been deleted.
+ if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
+ c_api.TF_DeleteApiDefMap(self._api_def_map)
+
+ def put_api_def(self, text):
+ c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))
+
+ def get_api_def(self, op_name):
+ api_def_proto = api_def_pb2.ApiDef()
+ buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
+ try:
+ api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
+ finally:
+ c_api.TF_DeleteBuffer(buf)
+ return api_def_proto
+
+ def get_op_def(self, op_name):
+ if op_name in self._op_per_name:
+ return self._op_per_name[op_name]
+ raise ValueError("No entry found for " + op_name + ".")
+
+
@tf_contextlib.contextmanager
def tf_buffer(data=None):
"""Context manager that creates and deletes TF_Buffer.
diff --git a/tensorflow/python/framework/c_api_util_test.py b/tensorflow/python/framework/c_api_util_test.py
new file mode 100644
index 0000000000..e0bc9ee531
--- /dev/null
+++ b/tensorflow/python/framework/c_api_util_test.py
@@ -0,0 +1,55 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for c_api utils."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class ApiDefMapTest(test_util.TensorFlowTestCase):
+
+ def testApiDefMapGet(self):
+ api_def_map = c_api_util.ApiDefMap()
+ op_def = api_def_map.get_op_def("Add")
+ self.assertEqual(op_def.name, "Add")
+ api_def = api_def_map.get_api_def("Add")
+ self.assertEqual(api_def.graph_op_name, "Add")
+
+ def testApiDefMapPutThenGet(self):
+ api_def_map = c_api_util.ApiDefMap()
+ api_def_text = """
+op {
+ graph_op_name: "Add"
+ summary: "Returns x + y element-wise."
+ description: <<END
+*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+END
+}
+"""
+ api_def_map.put_api_def(api_def_text)
+ api_def = api_def_map.get_api_def("Add")
+ self.assertEqual(api_def.graph_op_name, "Add")
+ self.assertEqual(api_def.summary, "Returns x + y element-wise.")
+
+
+if __name__ == "__main__":
+ googletest.main()
+
diff --git a/tensorflow/python/framework/fast_tensor_util.pyx b/tensorflow/python/framework/fast_tensor_util.pyx
index 19928314ef..17d112a1ec 100644
--- a/tensorflow/python/framework/fast_tensor_util.pyx
+++ b/tensorflow/python/framework/fast_tensor_util.pyx
@@ -7,6 +7,18 @@ cimport numpy as np
from tensorflow.python.util import compat
+def AppendFloat16ArrayToTensorProto(
+ # For NumPy, npy_half is a typedef for npy_uint16,
+ # see: https://github.com/numpy/numpy/blob/master/doc/source/reference/c-api.coremath.rst#half-precision-functions
+ # Because np.float16_t doesn't exist in Cython, we use uint16_t here.
+ # TODO: Use np.float16_t when Cython supports it.
+ tensor_proto, np.ndarray[np.uint16_t, ndim=1] nparray):
+ cdef long i, n
+ n = nparray.size
+ for i in range(n):
+ tensor_proto.half_val.append(nparray[i])
+
+
def AppendFloat32ArrayToTensorProto(
tensor_proto, np.ndarray[np.float32_t, ndim=1] nparray):
cdef long i, n
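The new helper expects the half-precision data already reinterpreted as uint16, as the comment notes. A sketch of the NumPy-side view a caller would pass; the call site shown is illustrative, not the exact tensor_util code:

    import numpy as np

    values = np.array([1.0, 0.5, -2.0], dtype=np.float16)
    as_uint16 = values.view(np.uint16)  # same bytes, reinterpreted per element
    print(as_uint16.dtype, as_uint16.shape)  # uint16 (3,)
    # Each uint16 carries the raw IEEE half-precision bit pattern that is
    # appended to TensorProto.half_val.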
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index f82e94b1a3..94c37d65c3 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -248,6 +248,9 @@ class _DefinedFunction(object):
# Constructed only when C API is enabled, lazily
self._c_func = None
self._sub_functions = dict() # Constructed with _definition or _c_func
+ device_stack = ops.get_default_graph()._device_function_stack # pylint: disable=protected-access
+ # Get the innermost device if possible.
+ self._caller_device = device_stack[-1] if device_stack else None
# Cached OpDef for this function. When C API is enabled, this is
# the only part of FunctionDef that we cache in Python. When C API
@@ -313,6 +316,16 @@ class _DefinedFunction(object):
self._create_definition_if_needed()
return self._extra_inputs
+ @property
+ def stateful_ops(self):
+ """Returns the list of stateful ops in function definition.
+
+ Returns:
+ A list of (op.name, op.type) pairs.
+ """
+ self._create_definition_if_needed()
+ return self._stateful_ops
+
def _create_definition_if_needed(self):
"""Creates the function definition if it's not created yet."""
with context.graph_mode():
@@ -325,7 +338,7 @@ class _DefinedFunction(object):
# Create the func_def object.
temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
- with temp_graph.as_default():
+ with temp_graph.as_default(), ops.device(self._caller_device):
# List of placeholders for the function_def.
inputs = []
for (argname, argtype) in self._args:
@@ -424,6 +437,10 @@ class _DefinedFunction(object):
else:
self._func_name = compat.as_str(self._op_def.name)
+ self._stateful_ops = [(op.name, op.type)
+ for op in temp_graph.get_operations()
+ if op.op_def.is_stateful]
+
def _set_c_attrs(self, attrs):
"""Sets `attrs` as attributes of self._c_func.
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index a5c19f189e..124b1e85f6 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -182,6 +182,8 @@ class FunctionTest(test.TestCase):
def APlus2B(a, b):
return a + b * 2
+ # APlus2B is stateless.
+ self.assertEqual([], APlus2B.stateful_ops)
with ops.Graph().as_default():
call = APlus2B([1.0], [2.0])
self.assertEqual("APlus2B", call.op.name)
@@ -428,6 +430,8 @@ class FunctionTest(test.TestCase):
with ops.control_dependencies([check]):
return x * 2
+ # Foo contains a stateful op (Assert).
+ self.assertEqual([("Assert", "Assert")], Foo.stateful_ops)
g = ops.Graph()
with g.as_default(), self.test_session():
self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)
@@ -1719,5 +1723,53 @@ class VariableHoistingTest(test.TestCase):
self._testSimpleModel(False, use_resource=True)
+class DevicePlacementTest(test.TestCase):
+
+ def testNoDeviceGraph(self):
+ with ops.Graph().as_default():
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Matmul(a, b):
+ return math_ops.matmul(a, b)
+
+ Matmul(1., 2.)
+
+ gdef = ops.get_default_graph().as_graph_def()
+ self.assertAllEqual(len(gdef.library.function), 1)
+ fdef = gdef.library.function[0]
+
+ for node in fdef.node_def:
+ self.assertAllEqual(node.device, "")
+
+ def testNestedDevices(self):
+ with ops.Graph().as_default(), ops.device("CPU:0"):
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Matmul(a, b):
+ return math_ops.matmul(a, b)
+
+ with ops.device("CPU:1"):
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Divide(a, b):
+ return math_ops.divide(a, b)
+
+ Divide(Matmul(1., 2.), 3.)
+
+ gdef = ops.get_default_graph().as_graph_def()
+ matmul_fdef = [
+ f for f in gdef.library.function if "Matmul" in f.signature.name
+ ]
+ divide_fdef = [
+ f for f in gdef.library.function if "Divide" in f.signature.name
+ ]
+ self.assertAllEqual(len(matmul_fdef), 1)
+ self.assertAllEqual(len(divide_fdef), 1)
+ for node in matmul_fdef[0].node_def:
+ self.assertAllEqual(node.device, "/device:CPU:0")
+ for node in divide_fdef[0].node_def:
+ self.assertAllEqual(node.device, "/device:CPU:1")
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index 9a8477debb..535c6017f5 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -58,7 +58,7 @@ def load_op_library(library_filename):
op_list_str = py_tf.TF_GetOpList(lib_handle)
op_list = op_def_pb2.OpList()
op_list.ParseFromString(compat.as_bytes(op_list_str))
- wrappers = py_tf.GetEagerPythonWrappers(op_list_str)
+ wrappers = py_tf.GetPythonWrappers(op_list_str)
# Delete the library handle to release any memory held in C
# that are no longer needed.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index de3bf0032b..71825e4a50 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -3455,8 +3455,9 @@ class Graph(object):
# the name will still appear in _names_in_use even though the name hasn't
# been used. This is ok, just leave _names_in_use as-is in this case.
# TODO(skyewm): make the C API guarantee no name conflicts.
- if ret.name not in self._names_in_use:
- self._names_in_use[ret.name] = 1
+ name_key = ret.name.lower()
+ if name_key not in self._names_in_use:
+ self._names_in_use[name_key] = 1
self._create_op_helper(ret, compute_device=compute_device)
return ret
@@ -4172,20 +4173,27 @@ class Graph(object):
"""
if self._name_stack:
name = self._name_stack + "/" + name
- i = self._names_in_use.get(name, 0)
- # Increment the number for "name".
+
+ # For the sake of checking for names in use, we treat names as case
+ # insensitive (e.g. foo = Foo).
+ name_key = name.lower()
+ i = self._names_in_use.get(name_key, 0)
+ # Increment the number for "name_key".
if mark_as_used:
- self._names_in_use[name] = i + 1
+ self._names_in_use[name_key] = i + 1
if i > 0:
- base_name = name
- # Make sure the composed name is not already used.
- while name in self._names_in_use:
- name = "%s_%d" % (base_name, i)
+ base_name_key = name_key
+ # Make sure the composed name key is not already used.
+ while name_key in self._names_in_use:
+ name_key = "%s_%d" % (base_name_key, i)
i += 1
- # Mark the composed name as used in case someone wants
+ # Mark the composed name_key as used in case someone wants
# to call unique_name("name_1").
if mark_as_used:
- self._names_in_use[name] = 1
+ self._names_in_use[name_key] = 1
+
+ # Return the new name with the original capitalization of the given name.
+ name = "%s_%d" % (name, i-1)
return name
def get_name_scope(self):
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index c9c1a3d66b..7d6e3bab79 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -1063,6 +1063,15 @@ class NameStackTest(test_util.TensorFlowTestCase):
self.assertEqual("foo_1", g.unique_name("foo"))
self.assertEqual("foo_3", g.unique_name("foo"))
+ def testUniqueNameCaseInsensitivity(self):
+ g = ops.Graph()
+ self.assertEqual("foo", g.unique_name("foo"))
+ self.assertEqual("Foo_1", g.unique_name("Foo"))
+ with g.name_scope("bar"):
+ self.assertEqual("bar/foo", g.unique_name("foo"))
+ with g.name_scope("Bar"):
+ self.assertEqual("Bar_1/foo", g.unique_name("foo"))
+
def testInvalidNameRaisesError(self):
g = ops.Graph()
with g.name_scope(""): # Should not raise
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index ad6c36b4b1..ec3748b40e 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
#include "tensorflow/python/framework/python_op_gen.h"
#include <stdio.h>
@@ -26,8 +25,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
-#include "tensorflow/core/framework/tensor.pb.h"
-#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -41,792 +38,913 @@ limitations under the License.
#include "tensorflow/python/framework/python_op_gen_internal.h"
namespace tensorflow {
-namespace python_op_gen_internal {
+namespace {
const int kRightMargin = 78;
-bool IsPythonReserved(const string& s) {
- static const std::set<string>* const kPythonReserved = new std::set<string>(
- {// Keywords in Python, from:
- // import keyword
- // print keyword.kwlist
- "and", "as", "assert", "break", "class", "continue", "def", "del",
- "elif", "else", "except", "exec", "finally", "for", "from", "global",
- "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",
- "raise", "return", "try", "while", "with", "yield",
- // Built-in functions and types in Python, from:
- // [x for x in dir(__builtins__) if not x[0].islower()]
- "ArithmeticError", "AssertionError", "AttributeError", "BaseException",
- "BufferError", "BytesWarning", "DeprecationWarning", "EOFError",
- "Ellipsis", "EnvironmentError", "Exception", "False",
- "FloatingPointError", "FutureWarning", "GeneratorExit", "IOError",
- "ImportError", "ImportWarning", "IndentationError", "IndexError",
- "KeyError", "KeyboardInterrupt", "LookupError", "MemoryError",
- "NameError", "None", "NotImplemented", "NotImplementedError", "OSError",
- "OverflowError", "PendingDeprecationWarning", "ReferenceError",
- "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration",
- "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit", "TabError",
- "True", "TypeError", "UnboundLocalError", "UnicodeDecodeError",
- "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
- "UnicodeWarning", "UserWarning", "ValueError", "Warning",
- "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__",
- "__package__"});
-
- return kPythonReserved->count(s) > 0;
-}
+constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
-bool IsOpWithUnderscorePrefix(const string& s) {
- static const std::set<string>* const kUnderscoreOps = new std::set<string>(
- {// Lowercase built-in functions and types in Python, from:
- // [x for x in dir(__builtins__) if x[0].islower()] except "round".
- // These need to be excluded so they don't conflict with actual built-in
- // functions since we use '*' imports.
- "abs", "all", "any", "apply", "bin", "bool", "buffer", "bytearray",
- "bytes", "callable", "chr", "classmethod", "cmp", "coerce", "compile",
- "complex", "copyright", "credits", "delattr", "dict", "dir", "divmod",
- "enumerate", "eval", "execfile", "exit", "file", "filter", "float",
- "format", "frozenset", "getattr", "globals", "hasattr", "hash", "help",
- "hex", "id", "input", "int", "intern", "isinstance", "issubclass",
- "iter", "len", "license", "list", "locals", "long", "map", "max",
- "memoryview", "min", "next", "object", "oct", "open", "ord", "pow",
- "print", "property", "quit", "range", "raw_input", "reduce", "reload",
- "repr", "reversed", "set", "setattr", "slice", "sorted", "staticmethod",
- "str", "sum", "super", "tuple", "type", "unichr", "unicode", "vars",
- "xrange", "zip",
- // These have the same name as ops defined in Python and might be used
- // incorrectly depending on order of '*' imports.
- // TODO(annarev): reduce usage of '*' imports and remove these from the
- // list.
- "fused_batch_norm", "histogram_fixed_width", "stack",
- "batch_norm_with_global_normalization", "clip_by_value"});
- return kUnderscoreOps->count(s) > 0;
+string AttrVarName(const string& attr_name,
+ std::unordered_map<string, string>* attr_expressions) {
+ const string var = strings::StrCat("_attr_", attr_name);
+ if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
+ return var;
}
-string AvoidPythonReserved(const string& s) {
- if (IsPythonReserved(s)) return strings::StrCat(s, "_");
- return s;
+void AddInferredAttr(const string& indentation, const string& attr_name,
+ const string& value_expression, string* result,
+ std::unordered_map<string, string>* attr_expressions) {
+ strings::StrAppend(result, indentation,
+ AttrVarName(attr_name, attr_expressions), " = ",
+ value_expression, "\n");
}
-// Indent the first line by "initial" spaces and all following lines
-// by "rest" spaces.
-string Indent(int initial, int rest, StringPiece in) {
- // TODO(josh11b): Also word-wrapping?
- string copy(in.data(), in.size());
- str_util::StripTrailingWhitespace(&copy);
- std::vector<string> v = str_util::Split(copy, '\n');
+string VectorToTuple(const std::vector<string>& l) {
+ if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
+ string ret = "(";
+ for (int i = 0; i < l.size(); ++i) {
+ if (i > 0) {
+ strings::StrAppend(&ret, ", ");
+ }
+ strings::StrAppend(&ret, l[i]);
+ }
+ strings::StrAppend(&ret, ")");
+ return ret;
+}
- string result;
- bool first = true;
- for (const string& line : v) {
- if (first) {
- result = strings::StrCat(Spaces(initial), line, "\n");
- first = false;
- } else {
- if (line.empty()) {
- strings::StrAppend(&result, "\n");
+void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
+ const string& var, string* result) {
+ for (int i = 0; i < output_sizes.size(); ++i) {
+ if (!output_sizes[i].empty()) {
+ strings::StrAppend(result, prefix, var, " = ");
+ if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
+ if (i + 1 < output_sizes.size()) {
+ // Special case i == 0 to avoid "0 +" in the generated code.
+ if (i == 0) {
+ strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
+ var, "[", output_sizes[i], ":]");
+ } else {
+ strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
+ output_sizes[i], "]] + ", var, "[", i, " + ",
+ output_sizes[i], ":]");
+ }
} else {
- strings::StrAppend(&result, Spaces(rest), line, "\n");
+ strings::StrAppend(result, "[", var, "[", i, ":]]");
}
+ strings::StrAppend(result, "\n");
}
}
- return result;
}
-// Adds append to *dest, with a space if the first line will be <= width,
-// or a newline otherwise.
-void AppendWithinWidth(string* dest, StringPiece append, int width) {
- auto first_line = append.find('\n');
- if (first_line == string::npos) first_line = append.size();
- if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) {
- strings::StrAppend(dest, "\n", append);
- } else {
- strings::StrAppend(dest, " ", append);
- }
+string TensorPBString(const TensorProto& pb) {
+ // Note: This gets used in the argument list, and so must survive naive
+ // word wrapping.
+ return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
}
-// Like DataTypeString() but uses the Python names for the
-// float types.
-string PythonDataTypeString(DataType dtype) {
- switch (dtype) {
- case DT_FLOAT:
- return "float32";
- case DT_DOUBLE:
- return "float64";
- default:
- return DataTypeString(dtype);
+const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
+ for (int i = 0; i < api_def.in_arg_size(); ++i) {
+ if (api_def.in_arg(i).name() == name) {
+ return &api_def.in_arg(i);
+ }
}
+ return nullptr;
}
-string TypeString(DataType dtype, bool ref) {
- if (ref) {
- return strings::StrCat("mutable `", PythonDataTypeString(dtype), "`");
- } else {
- return strings::StrCat("`", PythonDataTypeString(dtype), "`");
+class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
+ public:
+ GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
+ const string& function_name)
+ : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) {
+ op_name_ = function_name_;
+ str_util::ConsumePrefix(&op_name_, "_");
}
-}
-
-string TypeListString(const AttrValue& value) {
- string ret;
- for (int t : value.list().type()) {
- if (!ret.empty()) strings::StrAppend(&ret, ", ");
- DataType dtype = static_cast<DataType>(t);
- if (IsRefType(dtype)) {
- strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)),
- " mutable");
+ ~GenEagerPythonOp() override {}
+
+ string Code() override;
+
+ protected:
+ void HandleGraphMode(const string& function_setup);
+
+ string GetEagerNotAllowedError();
+ void ExpectListArg(const string& indentation, const string& arg_name,
+ string* output);
+ bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
+ void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
+ string* num_outputs_expr);
+
+ void AddEagerFunctionTeardown(const string& indentation,
+ const std::vector<string>& output_sizes,
+ bool execute_record_gradient);
+
+ bool AddEagerFastPathAndGraphCode(const string& parameters,
+ const std::vector<string>& output_sizes,
+ const string& eager_not_allowed_error);
+ bool AddEagerFallbackCode(const string& parameters,
+ const std::vector<string>& output_sizes,
+ const string& num_outputs_expr,
+ const string& eager_not_allowed_error);
+ void AddEagerFastPathExecute();
+
+ void AddEagerInferredAttrs(const string& indentation);
+ void AddEagerInputCasts(const string& indentation);
+ void AddEagerAttrs(const string& indentation);
+ void AddEagerExecute(const string& indentation,
+ const string& num_outputs_expr);
+
+ void AddAttrForArg(const string& attr, int arg_index) {
+ gtl::InsertIfNotPresent(&inferred_attrs_, attr,
+ op_def_.input_arg(arg_index).name());
+ auto iter = attr_to_args_.find(attr);
+ if (iter == attr_to_args_.end()) {
+ attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
} else {
- strings::StrAppend(&ret, "`", PythonDataTypeString(dtype), "`");
+ iter->second.push_back(arg_index);
}
}
- return ret;
-}
-string SingleTensorName(DataType dtype, bool is_ref) {
- const string type_str = TypeString(dtype, is_ref);
- return strings::StrCat("A `Tensor` of type ", type_str, ".");
-}
+ // Returns a string expression representing a flattened list of all
+ // the inputs given by `*input_indices` (or all inputs if
+ // `input_indices` is nullptr). `*output_sizes` can be used to unflatten.
+ string FlattenInputs(const std::vector<int>* input_indices,
+ std::vector<string>* output_sizes) const;
-const char kUnknownTensorType[] = {"A `Tensor`."};
-
-string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
- const std::unordered_map<string, string>& inferred_attrs,
- bool is_output) {
- if (!arg.number_attr().empty()) {
- // N Tensors with the same type
- const string* original_arg =
- gtl::FindOrNull(inferred_attrs, arg.number_attr());
- string prefix;
- if (original_arg == nullptr) {
- prefix = strings::StrCat("A list of `", arg.number_attr(), "`");
- } else if (*original_arg == arg.name()) {
- const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
- if (attr->has_minimum() && attr->minimum() > 0) {
- prefix = strings::StrCat("A list of at least ", attr->minimum());
- } else {
- prefix = "A list of";
- }
- } else {
- prefix = strings::StrCat("A list with the same length as `",
- AvoidPythonReserved(*original_arg), "` of");
- }
+ StringPiece op_name_;
+ typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
+ AttrToArgMap attr_to_args_;
+ std::unordered_map<string, string> attr_expressions_;
+ // This has all the input args followed by those attrs that don't have
+ // defaults.
+ std::vector<python_op_gen_internal::ParamNames> params_no_default_;
+ // The parameters with defaults (these have to be listed after those without).
+ // No input args are included, just attrs.
+ std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
+ params_with_default_;
+};
- if (arg.type() != DT_INVALID) {
- return strings::StrCat(prefix, " `Tensor` objects with type ",
- TypeString(arg.type(), arg.is_ref()), ".");
- } else {
- original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr());
- if (arg.is_ref()) {
- strings::StrAppend(&prefix, " mutable");
+string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
+ const string& function_name) {
+ return GenEagerPythonOp(op_def, api_def, function_name).Code();
+}
+
+string GenEagerPythonOp::FlattenInputs(
+ const std::vector<int>* input_indices,
+ std::vector<string>* output_sizes) const {
+ string inputs;
+ enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
+ const int n = input_indices != nullptr ? input_indices->size()
+ : op_def_.input_arg_size();
+ for (int j = 0; j < n; ++j) {
+ const int i = input_indices ? (*input_indices)[j] : j;
+ const auto& arg(op_def_.input_arg(i));
+ const bool is_list =
+ !arg.type_list_attr().empty() || !arg.number_attr().empty();
+ if (is_list) {
+ if (inputs_state == WAS_SOLO_INPUT) {
+ strings::StrAppend(&inputs, "] + ");
+ } else if (inputs_state == WAS_LIST_INPUT) {
+ strings::StrAppend(&inputs, " + ");
}
- if (original_arg == nullptr) {
- return strings::StrCat(prefix, " `Tensor` objects with type `",
- arg.type_attr(), "`.");
- } else if (*original_arg == arg.name()) {
- const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
- if (attr->has_allowed_values()) {
- return strings::StrCat(prefix,
- " `Tensor` objects with the same type in: ",
- TypeListString(attr->allowed_values()), ".");
+ strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
+ inputs_state = WAS_LIST_INPUT;
+ if (output_sizes != nullptr) {
+ if (!arg.number_attr().empty()) {
+ output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
} else {
- return strings::StrCat(prefix,
- " `Tensor` objects with the same type.");
+ output_sizes->emplace_back(
+ strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
}
- } else {
- return strings::StrCat(prefix,
- " `Tensor` objects with the same type as `",
- AvoidPythonReserved(*original_arg), "`.");
}
- }
- } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) {
- const bool is_list = !arg.type_list_attr().empty();
- const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr();
- const OpDef::AttrDef* attr = FindAttr(attr_name, op_def);
- const string mutable_str = arg.is_ref() ? "mutable " : "";
- const string prefix =
- is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects")
- : strings::StrCat("A ", mutable_str, "`Tensor`");
- const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name);
- if (original_arg == nullptr) {
- return strings::StrCat(prefix, " of type `", attr_name, "`.");
- } else if (*original_arg == arg.name()) {
- if (attr->has_allowed_values()) {
- if (is_list) {
- return strings::StrCat(prefix, " with types from: ",
- TypeListString(attr->allowed_values()), ".");
- } else {
- return strings::StrCat(
- prefix, is_output ? ". Has one of the following types: "
- : ". Must be one of the following types: ",
- TypeListString(attr->allowed_values()), ".");
- }
+ } else {
+ if (inputs_state == WAS_SOLO_INPUT) {
+ strings::StrAppend(&inputs, ", ");
+ } else if (inputs_state == WAS_LIST_INPUT) {
+ strings::StrAppend(&inputs, " + [");
} else {
- return strings::StrCat(prefix, ".");
+ strings::StrAppend(&inputs, "[");
}
- } else {
- return strings::StrCat(prefix,
- is_output ? ". Has the same type as `"
- : ". Must have the same type as `",
- AvoidPythonReserved(*original_arg), "`.");
+ strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
+ inputs_state = WAS_SOLO_INPUT;
+ if (output_sizes != nullptr) output_sizes->emplace_back();
}
- } else {
- return SingleTensorName(arg.type(), arg.is_ref());
}
+ if (inputs_state == STARTING) return "[]";
+ if (inputs_state == WAS_SOLO_INPUT) {
+ strings::StrAppend(&inputs, "]");
+ }
+ return inputs;
}
-string GetReturns(const OpDef& op_def,
- const std::vector<string>& output_type_string) {
- string result;
- DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
- const int num_outs = op_def.output_arg_size();
- strings::StrAppend(&result, "\n Returns:\n");
- if (num_outs == 0) {
- strings::StrAppend(&result, " The created Operation.\n");
- } else {
- if (num_outs == 1) {
- StringPiece description = op_def.output_arg(0).description();
- if (ConsumeEquals(&description)) { // Skip the generated type info.
- strings::StrAppend(&result, Indent(4, 4, description));
- } else {
- // Special case of one output, don't use the name of the output unless
- // there is no description.
- string desc = output_type_string.empty() ? kUnknownTensorType
- : output_type_string[0];
- if (desc == kUnknownTensorType) {
- // Special case where we don't understand how the output tensor type
- // depends on the input tensor types, just use the output arg
- // description if we can.
- if (!description.empty()) {
- desc = op_def.output_arg(0).description();
- } else if (!op_def.output_arg(0).name().empty()) {
- desc = strings::StrCat(" The ", op_def.output_arg(0).name(),
- " `Tensor`.");
+string GenEagerPythonOp::Code() {
+ if (api_def_.visibility() == ApiDef::SKIP) {
+ return "";
+ }
+
+ for (int i = 0; i < api_def_.arg_order_size(); ++i) {
+ const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
+ const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
+ params_no_default_.emplace_back(api_def_arg.name(),
+ api_def_arg.rename_to());
+ if (!arg.type_attr().empty()) {
+ AddAttrForArg(arg.type_attr(), i);
+ } else if (!arg.type_list_attr().empty()) {
+ AddAttrForArg(arg.type_list_attr(), i);
+ }
+ if (!arg.number_attr().empty()) {
+ AddAttrForArg(arg.number_attr(), i);
+ }
+ }
+ for (int i = 0; i < op_def_.attr_size(); ++i) {
+ const auto& attr(op_def_.attr(i));
+ const auto& api_def_attr(api_def_.attr(i));
+ // Do not add inferred attrs to the Python function signature.
+ if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
+ if (api_def_attr.has_default_value()) {
+ if (attr.type() == "tensor") {
+ params_with_default_.emplace_back(
+ python_op_gen_internal::ParamNames(api_def_attr.name(),
+ api_def_attr.rename_to()),
+ strings::StrCat(
+ "_execute.make_tensor(",
+ TensorPBString(api_def_attr.default_value().tensor()), ", \"",
+ api_def_attr.rename_to(), "\")"));
+ } else if (attr.type() == "list(tensor)") {
+ std::vector<string> pbtxt;
+ for (const auto& pb : api_def_attr.default_value().list().tensor()) {
+ pbtxt.emplace_back(TensorPBString(pb));
}
- } else if (!description.empty()) {
- AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
- }
- strings::StrAppend(&result, Indent(4, 4, desc));
- }
- } else {
- std::vector<string> out_names(num_outs);
- for (int i = 0; i < num_outs; ++i) {
- if (!op_def.output_arg(i).name().empty()) {
- out_names[i] = op_def.output_arg(i).name();
- } else {
- out_names[i] = strings::StrCat("output", i);
- }
- }
- strings::StrAppend(&result, " A tuple of `Tensor` objects (",
- str_util::Join(out_names, ", "), ").\n\n");
- for (int i = 0; i < num_outs; ++i) {
- string desc = strings::StrCat(out_names[i], ": ");
- StringPiece description = op_def.output_arg(i).description();
- if (ConsumeEquals(&description)) { // Skip the generated type info.
- strings::StrAppend(&desc, description);
+ params_with_default_.emplace_back(
+ python_op_gen_internal::ParamNames(api_def_attr.name(),
+ api_def_attr.rename_to()),
+ strings::StrCat("[_execute.make_tensor(_pb, \"",
+ api_def_attr.rename_to(), "\") for _pb in ",
+ VectorToTuple(pbtxt), "]"));
} else {
- const string type = static_cast<size_t>(i) < output_type_string.size()
- ? output_type_string[i]
- : kUnknownTensorType;
- if (!description.empty()) {
- if (type == kUnknownTensorType) {
- // Special case where we don't understand how the output tensor
- // type depends on the input tensor types, so we just use the
- // output arg description.
- strings::StrAppend(&desc, description);
- } else {
- strings::StrAppend(&desc, type, " ", description);
- }
- } else {
- strings::StrAppend(&desc, type);
- }
+ params_with_default_.emplace_back(
+ python_op_gen_internal::ParamNames(api_def_attr.name(),
+ api_def_attr.rename_to()),
+ python_op_gen_internal::AttrValueToPython(
+ attr.type(), api_def_attr.default_value(), "_dtypes."));
}
- strings::StrAppend(&result, Indent(4, 6, desc));
+ } else {
+ params_no_default_.emplace_back(api_def_attr.name(),
+ api_def_attr.rename_to());
}
}
}
- return result;
-}
-string StringToPython(const string& str) {
- return strings::StrCat("\"", str_util::CEscape(str), "\"");
-}
+ // Save the list of attr parameters (attrs that won't be inferred),
+ // those with defaults go at the end.
+ // Get the attrs in the order we want by taking the attrs without defaults
+ // from the end of params_no_default_, and adding params_no_default_.
+ attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
+ params_with_default_.size());
+ for (int i = op_def_.input_arg_size(); i < params_no_default_.size(); ++i) {
+ attrs_.push_back(params_no_default_[i].GetName());
+ }
+ for (const auto& p : params_with_default_) {
+ attrs_.push_back(p.first.GetName());
+ }
-string DataTypeToPython(DataType dtype, const string& dtype_module) {
- return strings::StrCat(dtype_module, PythonDataTypeString(dtype));
-}
+ param_names_.reserve(params_no_default_.size() + params_with_default_.size());
+ param_names_.insert(param_names_.begin(), params_no_default_.begin(),
+ params_no_default_.end());
+ for (const auto& param_and_default : params_with_default_) {
+ param_names_.push_back(param_and_default.first);
+ }
-string ShapeToPython(const TensorShapeProto& shape) {
- if (shape.unknown_rank()) {
- return "None";
+ string parameters;
+ for (const auto& param : params_no_default_) {
+ if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+ strings::StrAppend(&parameters, param.GetRenameTo());
}
- string python = "[";
- for (const auto& dim : shape.dim()) {
- if (python.size() > 1) strings::StrAppend(&python, ", ");
- if (!dim.name().empty()) {
- strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ",
- dim.size(), ")");
- } else {
- strings::StrAppend(&python, dim.size());
+ for (const auto& param_and_default : params_with_default_) {
+ if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+ strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), "=",
+ param_and_default.second);
+ }
+ if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+ strings::StrAppend(&parameters, "name=None");
+
+ // Add attr_expressions_ for attrs that are params.
+ for (int i = 0; i < attrs_.size(); ++i) {
+ const string& attr_name = attrs_[i];
+ const string& attr_api_name =
+ param_names_[i + op_def_.input_arg_size()].GetRenameTo();
+ attr_expressions_[attr_name] = attr_api_name;
+ }
+ // Add attr_expressions_ for attrs that are inferred.
+ for (int i = 0; i < op_def_.attr_size(); ++i) {
+ const auto& attr(op_def_.attr(i));
+ if (attr.type() == "int") {
+ auto arg_list = attr_to_args_.find(attr.name());
+ if (arg_list != attr_to_args_.end()) {
+ AttrVarName(attr.name(), &attr_expressions_);
+ }
}
}
- strings::StrAppend(&python, "]");
- return python;
-}
-string TensorToPython(const TensorProto& proto) {
- return ProtoShortDebugString(proto);
-}
+ string num_outputs_expr;
+ std::vector<string> output_sizes(num_outs_);
+ GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);
-string AttrListToPython(const AttrValue& value,
- const string& dtype_module = "tf.") {
- string ret;
- if (value.list().s_size() > 0) {
- for (int i = 0; i < value.list().s_size(); ++i) {
- if (i > 0) strings::StrAppend(&ret, ", ");
- strings::StrAppend(&ret, StringToPython(value.list().s(i)));
- }
- } else if (value.list().i_size() > 0) {
- for (int i = 0; i < value.list().i_size(); ++i) {
- if (i > 0) strings::StrAppend(&ret, ", ");
- strings::StrAppend(&ret, value.list().i(i));
- }
- } else if (value.list().f_size() > 0) {
- for (int i = 0; i < value.list().f_size(); ++i) {
- if (i > 0) strings::StrAppend(&ret, ", ");
- strings::StrAppend(&ret, value.list().f(i));
- }
- } else if (value.list().b_size() > 0) {
- for (int i = 0; i < value.list().b_size(); ++i) {
- if (i > 0) strings::StrAppend(&ret, ", ");
- strings::StrAppend(&ret, value.list().b(i) ? "True" : "False");
- }
- } else if (value.list().type_size() > 0) {
- for (int i = 0; i < value.list().type_size(); ++i) {
- if (i > 0) strings::StrAppend(&ret, ", ");
- strings::StrAppend(&ret,
- DataTypeToPython(value.list().type(i), dtype_module));
- }
- } else if (value.list().shape_size() > 0) {
- for (int i = 0; i < value.list().shape_size(); ++i) {
- if (i > 0) strings::StrAppend(&ret, ", ");
- strings::StrAppend(&ret, ShapeToPython(value.list().shape(i)));
- }
- } else if (value.list().tensor_size() > 0) {
- for (int i = 0; i < value.list().tensor_size(); ++i) {
- if (i > 0) strings::StrAppend(&ret, ", ");
- strings::StrAppend(&ret, TensorToPython(value.list().tensor(i)));
- }
- } else if (value.list().func_size() > 0) {
- for (int i = 0; i < value.list().func_size(); ++i) {
- if (i > 0) strings::StrAppend(&ret, ", ");
- strings::StrAppend(&ret, StringToPython(value.list().func(i).name()));
- }
+ string eager_not_allowed_error = GetEagerNotAllowedError();
+
+ if (!AddEagerFastPathAndGraphCode(parameters, output_sizes,
+ eager_not_allowed_error)) {
+ return result_;
}
- return ret;
+
+ if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
+ eager_not_allowed_error)) {
+ return result_;
+ }
+
+ return prelude_ + result_;
}
-// NOTE: The return value may contain spaces (for example, it could be
-// a string "foo bar" with an embedded space) and is not safe to pass
-// to WordWrap().
-string AttrValueToPython(const string& type, const AttrValue& value,
- const string& dtype_module) {
- if (type == "string") {
- return StringToPython(value.s());
- } else if (type == "int") {
- return strings::StrCat(value.i());
- } else if (type == "float") {
- if (std::isnan(value.f()) || std::isinf(value.f())) {
- return strings::StrCat("float('", value.f(), "')");
+void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
+ // Handle graph-mode case
+ strings::StrAppend(&result_,
+ " _ctx = _context._context\n"
+ " if _ctx is None or not _ctx._eager_context.is_eager:\n",
+ function_setup,
+ " _, _, _op = _op_def_lib._apply_op_helper(\n");
+ AddBodyNoReturn(" ");
+ if (num_outs_ > 0) {
+ strings::StrAppend(&result_, " _result = _op.outputs[:]\n");
+ // Special case handling for stateful op with single list output
+ // that might be empty.
+ if (num_outs_ == 1 && op_def_.is_stateful() &&
+ (!op_def_.output_arg(0).number_attr().empty() ||
+ !op_def_.output_arg(0).type_list_attr().empty())) {
+ // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
+ // a constraint indicating that this can never be empty.
+ strings::StrAppend(&result_,
+ " if not _result:\n"
+ " return _op\n");
+ }
+ strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n");
+
+ // Compute graph-mode attrs.
+ if (op_def_.attr_size() > 0) {
+ string attr_values;
+ for (int i = 0; i < op_def_.attr_size(); ++i) {
+ if (i > 0) strings::StrAppend(&attr_values, ", ");
+ const auto& attr_name(op_def_.attr(i).name());
+ strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
+ attr_name, "\")");
+ }
+ strings::StrAppend(&attr_values, ")");
+ strings::StrAppend(&result_,
+ WordWrap(" _attrs = (", attr_values, kRightMargin),
+ "\n");
} else {
- return strings::StrCat(value.f());
+ strings::StrAppend(&result_, " _attrs = None\n");
}
- } else if (type == "bool") {
- return value.b() ? "True" : "False";
- } else if (type == "type") {
- return DataTypeToPython(value.type(), dtype_module);
- } else if (type == "shape") {
- return ShapeToPython(value.shape());
- } else if (type == "tensor") {
- return TensorToPython(value.tensor());
- } else if (type == "func") {
- return StringToPython(value.func().name());
- } else if (str_util::StartsWith(type, "list(")) {
- return strings::StrCat("[", AttrListToPython(value, dtype_module), "]");
} else {
- return "?";
+ strings::StrAppend(&result_, " return _op\n");
}
}
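
For reference, the graph-mode branch appended above renders roughly as the following generated Python for a hypothetical one-output op Foo with a single input x and an inferred type attr T (all names here are invented for illustration, and the function_setup lines are elided):

  _ctx = _context._context
  if _ctx is None or not _ctx._eager_context.is_eager:
    # ... function_setup validation/conversion lines go here ...
    _, _, _op = _op_def_lib._apply_op_helper(
        "Foo", x=x, name=name)
    _result = _op.outputs[:]
    _inputs_flat = _op.inputs
    _attrs = ("T", _op.get_attr("T"))
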
-void GenerateLowerCaseOpName(const string& str, string* result) {
- const char joiner = '_';
- const int last_index = str.size() - 1;
- for (int i = 0; i <= last_index; ++i) {
- const char c = str[i];
- // Emit a joiner only if a previous-lower-to-now-upper or a
- // now-upper-to-next-lower transition happens.
- if (isupper(c) && (i > 0)) {
- if (islower(str[i - 1]) || ((i < last_index) && islower(str[i + 1]))) {
- result->push_back(joiner);
- }
+string GenEagerPythonOp::GetEagerNotAllowedError() {
+ bool eager_allowed = true;
+ string ref_arg;
+ for (int i = 0; i < op_def_.input_arg_size(); ++i) {
+ const auto& arg = op_def_.input_arg(i);
+ if (arg.is_ref()) {
+ eager_allowed = false;
+ DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
+ ref_arg = api_def_.in_arg(i).rename_to();
+ }
+ }
+ for (int i = 0; i < op_def_.output_arg_size(); ++i) {
+ const auto& arg = op_def_.output_arg(i);
+ if (arg.is_ref()) {
+ eager_allowed = false;
+ DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
+ ref_arg = api_def_.out_arg(i).rename_to();
}
- result->push_back(tolower(c));
}
+
+ if (eager_allowed) return "";
+
+ return strings::StrCat("raise RuntimeError(\"", op_name_,
+ " op does not support eager execution. ", "Arg '",
+ ref_arg, "' is a ref.\")\n");
}
-static void AddDelimiter(string* append_to, const string& delim) {
- if (!append_to->empty()) strings::StrAppend(append_to, delim);
+void GenEagerPythonOp::ExpectListArg(const string& indentation,
+ const string& arg_name, string* output) {
+ strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
+ ", (list, tuple)):\n", indentation, " raise TypeError(\n",
+ indentation, " \"Expected list for '", arg_name,
+ "' argument to \"\n", indentation, " \"'", op_name_,
+ "' Op, not %r.\" % ", arg_name, ")\n");
}
-const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
- for (int i = 0; i < api_def.attr_size(); ++i) {
- if (api_def.attr(i).name() == name) {
- return &api_def.attr(i);
+bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
+ string* function_setup) {
+ // Validate list inputs, infer length attrs.
+ for (int i = 0; i < op_def_.attr_size(); ++i) {
+ const auto& attr(op_def_.attr(i));
+ if (attr.type() == "int") {
+ auto arg_list = attr_to_args_.find(attr.name());
+ if (arg_list != attr_to_args_.end()) {
+ // Inferred int attrs are the lengths of inputs. Validate those
+ // inputs are lists and have the same length.
+ for (auto iter = arg_list->second.begin();
+ iter != arg_list->second.end(); ++iter) {
+ const string& arg_api_name = param_names_[*iter].GetRenameTo();
+ ExpectListArg(indentation, arg_api_name, function_setup);
+ if (iter == arg_list->second.begin()) {
+ AddInferredAttr(indentation, attr.name(),
+ strings::StrCat("len(", arg_api_name, ")"),
+ function_setup, &attr_expressions_);
+ } else {
+ const auto& attr_var = attr_expressions_[attr.name()];
+ strings::StrAppend(
+ function_setup, indentation, "if len(", arg_api_name,
+ ") != ", attr_var, ":\n", indentation, " raise ValueError(\n",
+ indentation, " \"List argument '", arg_api_name, "' to '",
+ op_name_, "' Op with length %d \"\n", indentation,
+ " \"must match length %d of argument '",
+ inferred_attrs_[attr.name()], "'.\" %\n", indentation,
+ " (len(", arg_api_name, "), ", attr_var, "))\n");
+ }
+ }
+ }
}
}
- return nullptr;
-}
-const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
- for (int i = 0; i < api_def.in_arg_size(); ++i) {
- if (api_def.in_arg(i).name() == name) {
- return &api_def.in_arg(i);
+ for (int i = 0; i < attrs_.size(); ++i) {
+ const string& attr_name = attrs_[i];
+ const auto& param = param_names_[i + op_def_.input_arg_size()];
+ const auto& attr = *FindAttr(attr_name, op_def_);
+ const string& attr_api_name = param.GetRenameTo();
+ StringPiece attr_type = attr.type();
+ attr_expressions_[attr_name] = attr_api_name;
+ const int default_index = i - (attrs_.size() - params_with_default_.size());
+ if (default_index >= 0) {
+ const string& default_value = params_with_default_[default_index].second;
+ strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
+ " is None:\n");
+ strings::StrAppend(function_setup, indentation, " ", attr_api_name,
+ " = ", default_value, "\n");
+ }
+ if (str_util::StartsWith(attr_type, "list(")) {
+ ExpectListArg(indentation, attr_api_name, function_setup);
+ }
+
+ if (attr_type == "string") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = _execute.make_str(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
+ } else if (attr_type == "list(string)") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = [_execute.make_str(_s, \"", attr_api_name,
+ "\") for _s in ", attr_api_name, "]\n");
+ } else if (attr_type == "int") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = _execute.make_int(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
+ } else if (attr_type == "list(int)") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = [_execute.make_int(_i, \"", attr_api_name,
+ "\") for _i in ", attr_api_name, "]\n");
+ } else if (attr_type == "float") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = _execute.make_float(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
+ } else if (attr_type == "list(float)") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = [_execute.make_float(_f, \"", attr_api_name,
+ "\") for _f in ", attr_api_name, "]\n");
+ } else if (attr_type == "bool") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = _execute.make_bool(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
+ } else if (attr_type == "list(bool)") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = [_execute.make_bool(_b, \"", attr_api_name,
+ "\") for _b in ", attr_api_name, "]\n");
+ } else if (attr_type == "type") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = _execute.make_type(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
+ } else if (attr_type == "list(type)") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = [_execute.make_type(_t, \"", attr_api_name,
+ "\") for _t in ", attr_api_name, "]\n");
+ } else if (attr_type == "shape") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = _execute.make_shape(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
+ } else if (attr_type == "list(shape)") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = [_execute.make_shape(_s, \"", attr_api_name,
+ "\") for _s in ", attr_api_name, "]\n");
+ } else if (attr_type == "tensor") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = _execute.make_tensor(", attr_api_name, ", \"",
+ attr_api_name, "\")\n");
+ } else if (attr_type == "list(tensor)") {
+ strings::StrAppend(function_setup, indentation, attr_api_name,
+ " = [_execute.make_tensor(_t, \"", attr_api_name,
+ "\") for _t in ", attr_api_name, "]\n");
+ } else if (attr_type != "func") {
+ *function_setup =
+ strings::StrCat("# No definition for ", function_name_,
+ " since we don't support attrs with type\n"
+ "# '",
+ attr_type, "' right now.\n\n");
+ return false;
}
}
- return nullptr;
+ return true;
}
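
As a rough illustration (op, argument, and attr names below are invented), the setup text built here for an op with a list input values whose length is captured by an inferred int attr N, plus an explicit string attr mode defaulting to "wrap", reads approximately:

  if not isinstance(values, (list, tuple)):
    raise TypeError(
      "Expected list for 'values' argument to "
      "'foo' Op, not %r." % values)
  _attr_N = len(values)
  if mode is None:
    mode = "wrap"
  mode = _execute.make_str(mode, "mode")
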
-GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
- const string& function_name)
- : op_def_(op_def),
- api_def_(api_def),
- function_name_(function_name),
- num_outs_(op_def.output_arg_size()) {}
-
-GenPythonOp::~GenPythonOp() {}
-
-string GenPythonOp::Code() {
- // This has all the input args followed by those attrs that don't have
- // defaults.
- std::vector<ParamNames> params_no_default;
- // The parameters with defaults (these have to be listed after those without).
- // No input args are included, just attrs.
- std::vector<ParamNames> params_with_default;
-
- for (int i = 0; i < api_def_.arg_order_size(); ++i) {
- const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
- const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
- params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
- if (!arg.type_attr().empty()) {
- gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name());
- } else if (!arg.type_list_attr().empty()) {
- gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(),
- arg.name());
- }
+// If output i is list output, output_sizes[i] will be set to a
+// string with the python expression that will evaluate to its
+// length. output_sizes[i] is empty for non-list outputs.
+void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
+ std::vector<string>* output_sizes, string* num_outputs_expr) {
+ // Expression representing the number of outputs.
+ int num_fixed_outputs = 0;
+ for (int i = 0; i < num_outs_; ++i) {
+ const auto& arg(op_def_.output_arg(i));
if (!arg.number_attr().empty()) {
- gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name());
- }
- }
- for (int i = 0; i < api_def_.attr_size(); ++i) {
- const auto& attr(api_def_.attr(i));
- // Do not add inferred attrs to the Python function signature.
- if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
- if (attr.has_default_value()) {
- params_with_default.emplace_back(attr.name(), attr.rename_to());
+ if (!num_outputs_expr->empty()) {
+ strings::StrAppend(num_outputs_expr, " + ");
+ }
+ (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
+ strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
+ } else if (!arg.type_list_attr().empty()) {
+ if (!num_outputs_expr->empty()) {
+ strings::StrAppend(num_outputs_expr, " + ");
+ }
+ // Have to be careful to use an expression that works in both
+ // graph and eager paths here.
+ const auto iter = inferred_attrs_.find(arg.type_list_attr());
+ if (iter == inferred_attrs_.end()) {
+ (*output_sizes)[i] = strings::StrCat(
+ "len(", attr_expressions_[arg.type_list_attr()], ")");
} else {
- params_no_default.emplace_back(attr.name(), attr.rename_to());
+ (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
}
+ strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
+ } else {
+ ++num_fixed_outputs;
}
}
-
- // Save the list of attr parameters (attrs that won't be inferred),
- // those with defaults go at the end.
- // Get the attrs in the order we want by taking the attrs without defaults
- // from the end of args_no_default, and adding args_no_default.
- attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
- params_with_default.size());
- for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) {
- attrs_.push_back(params_no_default[i].GetName());
- }
- for (int i = 0; i < params_with_default.size(); ++i) {
- attrs_.push_back(params_with_default[i].GetName());
- }
-
- param_names_.reserve(params_no_default.size() + params_with_default.size());
- param_names_.insert(param_names_.begin(), params_no_default.begin(),
- params_no_default.end());
- for (const auto& param : params_with_default) {
- param_names_.push_back(param);
+ if (num_fixed_outputs > 0) {
+ if (!num_outputs_expr->empty()) {
+ strings::StrAppend(num_outputs_expr, " + ");
+ }
+ strings::StrAppend(num_outputs_expr, num_fixed_outputs);
+ } else if (num_outputs_expr->empty()) {
+ *num_outputs_expr = "0";
}
+}
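
As a concrete (hypothetical) example: an op whose first output is a list sized by an inferred int attr N and whose second output is a single tensor would yield roughly

  output_sizes == ["_attr_N", ""]
  num_outputs_expr == "_attr_N + 1"

while an op with only fixed outputs gets a plain count such as "2".
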
- string parameters;
- for (const auto& param : params_no_default) {
- AddDelimiter(&parameters, ", ");
- strings::StrAppend(&parameters, param.GetRenameTo());
- }
- for (const auto& param_and_default : params_with_default) {
- AddDelimiter(&parameters, ", ");
- strings::StrAppend(&parameters, param_and_default.GetRenameTo(), "=None");
+void GenEagerPythonOp::AddEagerFunctionTeardown(
+ const string& indentation, const std::vector<string>& output_sizes,
+ bool execute_record_gradient) {
+ if (num_outs_ > 0) {
+ if (execute_record_gradient) {
+ strings::StrAppend(&result_, indentation, "_execute.record_gradient(\n",
+ " \"", op_def_.name(),
+ "\", _inputs_flat, _attrs, _result, name)\n");
+ }
+ if (num_outs_ == 1 && !output_sizes[0].empty()) {
+ // Single list result.
+ } else if (num_outs_ == 1) {
+ // Execute returns a single-element list which we need to destructure.
+ strings::StrAppend(&result_, indentation, "_result, = _result\n");
+ } else {
+ // Have multiple outputs, so we will need to reformat the return
+ // value of execute() to be a list with one entry per op output
+ // (that entry will be a list of tensors if that output is of list
+ // type).
+ // For list outputs, convert the right subrange of _result into a list.
+ Unflatten(indentation, output_sizes, "_result", &result_);
+ // Convert to a named tuple.
+ strings::StrAppend(&result_, indentation, "_result = _", op_def_.name(),
+ "Output._make(_result)\n");
+ }
+ } else {
+ strings::StrAppend(&result_, indentation, "_result = None\n");
}
- AddDelimiter(&parameters, ", ");
- strings::StrAppend(&parameters, "name=None");
+ strings::StrAppend(&result_, indentation, "return _result\n\n");
+}
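
For a hypothetical two-output op Foo with no list outputs, the teardown appended here (with record_gradient enabled) comes out roughly as:

  _execute.record_gradient(
      "Foo", _inputs_flat, _attrs, _result, name)
  _result = _FooOutput._make(_result)
  return _result
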
+bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
+ const string& parameters, const std::vector<string>& output_sizes,
+ const string& eager_not_allowed_error) {
AddExport();
- AddDefLine(parameters);
+ AddDefLine(function_name_, parameters);
AddDocStringDescription();
AddDocStringArgs();
AddDocStringInputs();
AddDocStringAttrs();
AddDocStringNameArg();
- AddOutputGlobals();
+ AddOutputGlobals(); // Added to prelude_
AddDocStringOutputs();
strings::StrAppend(&result_, " \"\"\"\n");
- AddBody(" ");
- strings::StrAppend(&result_, "\n\n");
- return prelude_ + result_;
+ // Handle graph-mode case
+ string function_setup;
+ if (!GetEagerFunctionSetup(" ", &function_setup)) {
+ result_ = function_setup;
+ return false;
+ }
+ HandleGraphMode(function_setup);
+ AddEagerFunctionTeardown(" ", output_sizes,
+ true /* execute_record_gradient */);
+
+ // Handle eager-mode case
+ strings::StrAppend(&result_, " else:\n");
+
+ if (eager_not_allowed_error.empty()) {
+ AddEagerFastPathExecute();
+ } else {
+ strings::StrAppend(&result_, " ", eager_not_allowed_error);
+ }
+
+ strings::StrAppend(&result_, "\n\n");
+ return true;
}
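
Putting these pieces together, the wrapper emitted by this method has roughly the following shape for a hypothetical visible op Foo(x) (docstring and branch bodies abbreviated):

@tf_export('foo')
def foo(x, name=None):
  r"""TODO: add doc.
  """
  _ctx = _context._context
  if _ctx is None or not _ctx._eager_context.is_eager:
    # graph mode: _apply_op_helper, _attrs, record_gradient, return _result
    ...
  else:
    # eager fast path: TFE_Py_FastPathExecute, with fallback handling
    ...
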
-void GenPythonOp::AddExport() {
- if (api_def_.visibility() != ApiDef::VISIBLE) {
- return;
+bool GenEagerPythonOp::AddEagerFallbackCode(
+ const string& parameters, const std::vector<string>& output_sizes,
+ const string& num_outputs_expr, const string& eager_not_allowed_error) {
+ if (!eager_not_allowed_error.empty()) {
+ strings::StrAppend(&result_, " ", eager_not_allowed_error);
+ return true;
}
- strings::StrAppend(&result_, "@tf_export(");
+ AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix),
+ strings::StrCat(parameters, ", ctx=None"));
+ strings::StrAppend(
+ &result_, " r\"\"\"This is the slowpath function for Eager mode.\n");
+ strings::StrAppend(&result_, " This is for function ", function_name_,
+ "\n \"\"\"\n");
- // Add all endpoint names to tf_export.
- bool first_endpoint = true;
- for (const auto& endpoint : api_def_.endpoint()) {
- if (!first_endpoint) {
- strings::StrAppend(&result_, ", ");
- } else {
- first_endpoint = false;
- }
- string endpoint_name;
- python_op_gen_internal::GenerateLowerCaseOpName(endpoint.name(),
- &endpoint_name);
- strings::StrAppend(&result_, "'", endpoint_name, "'");
+ strings::StrAppend(&result_, " _ctx = ctx if ctx else _context.context()\n");
+
+ string function_setup;
+ if (!GetEagerFunctionSetup(" ", &function_setup)) {
+ result_ = function_setup;
+ return false;
}
- strings::StrAppend(&result_, ")\n");
-}
+ strings::StrAppend(&result_, function_setup);
-void GenPythonOp::AddDefLine(const string& function_name,
- const string& parameters) {
- strings::StrAppend(&result_, "def ", function_name, "(", parameters, "):\n");
-}
+ AddEagerInferredAttrs(" ");
+ AddEagerInputCasts(" ");
+ strings::StrAppend(
+ &result_, " _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
+ AddEagerAttrs(" ");
+ AddEagerExecute(" ", num_outputs_expr);
-void GenPythonOp::AddDefLine(const string& parameters) {
- AddDefLine(function_name_, parameters);
+ AddEagerFunctionTeardown(" ", output_sizes,
+ true /* execute_record_gradient */);
+
+ return true;
}
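
The fallback wrapper generated here looks roughly like the following for the same hypothetical Foo(x) op with one inferred type attr T and one output (kEagerFallbackSuffix is "_eager_fallback" in the generated wrappers; exact line wrapping omitted):

def foo_eager_fallback(x, name=None, ctx=None):
  r"""This is the slowpath function for Eager mode.
  This is for function foo
  """
  _ctx = ctx if ctx else _context.context()
  _attr_T, (x,) = _execute.args_to_matching_eager([x], _ctx)
  _inputs_flat = [x]
  _attrs = ("T", _attr_T)
  _result = _execute.execute(b"Foo", 1, inputs=_inputs_flat, attrs=_attrs,
                             ctx=_ctx, name=name)
  _execute.record_gradient(
      "Foo", _inputs_flat, _attrs, _result, name)
  _result, = _result
  return _result
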
-void GenPythonOp::AddDocStringDescription() {
- string comment;
- if (api_def_.summary().empty()) {
- comment = "TODO: add doc.\n";
- } else {
- comment = strings::StrCat(api_def_.summary(), "\n");
- if (!api_def_.description().empty()) {
- strings::StrAppend(&comment, "\n", Indent(2, 2, api_def_.description()));
- }
+void GenEagerPythonOp::AddEagerFastPathExecute() {
+ string fastpath_execute_params = strings::StrCat(
+ "_ctx._context_handle, _ctx._eager_context.device_name, \"",
+ op_def_.name(), "\", ", "name, _ctx._post_execution_callbacks");
+ string fallback_params;
+
+ for (int i = 0; i < api_def_.in_arg_size(); i++) {
+ const string param_name = param_names_[i].GetRenameTo();
+ strings::StrAppend(&fastpath_execute_params, ", ", param_name);
+ if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
+ strings::StrAppend(&fallback_params, param_name);
}
- strings::StrAppend(&result_, " r\"\"\"", comment, "\n");
-}
-void GenPythonOp::AddDocStringArgs() {
- strings::StrAppend(&result_, " Args:\n");
-}
+ for (const auto& attr : api_def_.attr()) {
+ if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
+ strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
+ attr.rename_to());
-void GenPythonOp::AddDocStringInputs() {
- for (int i = 0; i < api_def_.arg_order_size(); ++i) {
- const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
- const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
- StringPiece description = api_def_arg.description();
- string desc;
- if (ConsumeEquals(&description)) { // Skip the generated type info.
- desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ");
- } else {
- desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ",
- ArgTypeName(op_def_, arg, inferred_attrs_, false));
+ if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
+ strings::StrAppend(&fallback_params, attr.rename_to(), "=",
+ attr.rename_to());
}
- if (!description.empty()) {
- AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
- }
- strings::StrAppend(&result_, Indent(4, 6, desc));
}
+
+ if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
+ strings::StrAppend(&fallback_params, "name=name");
+
+ strings::StrAppend(&result_, " try:\n");
+ strings::StrAppend(
+ &result_, " ",
+ "_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n",
+ WordWrap(strings::StrCat(" "),
+ strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
+ "\n");
+
+ if (op_def_.output_arg_size() > 1) {
+ const string output_tuple_name =
+ strings::StrCat("_", op_def_.name(), "Output");
+ strings::StrAppend(&result_, " ", "_result = ", output_tuple_name,
+ "._make(_result)\n");
+ }
+ strings::StrAppend(&result_, " ", "return _result\n");
+
+ // Handle fallback.
+ if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
+ strings::StrAppend(&fallback_params, "ctx=_ctx");
+ strings::StrAppend(&result_, " ", "except _core._FallbackException:\n");
+ strings::StrAppend(
+ &result_, " ", "return ", function_name_, kEagerFallbackSuffix,
+ "(\n",
+ WordWrap(strings::StrCat(" "),
+ strings::StrCat(fallback_params, ")"), kRightMargin),
+ "\n");
+
+ // Any errors thrown from execute need to be unwrapped from
+ // _NotOkStatusException.
+ strings::StrAppend(&result_, " ",
+ "except _core._NotOkStatusException as e:\n");
+ strings::StrAppend(&result_, " ", "if name is not None:\n");
+ strings::StrAppend(&result_, " ",
+ "message = e.message + \" name: \" + name\n");
+ strings::StrAppend(&result_, " ", "else:\n");
+ strings::StrAppend(&result_, " ", "message = e.message\n");
+ strings::StrAppend(
+ &result_, " ",
+ "_six.raise_from(_core._status_to_exception(e.code, message), None)\n");
}
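
Concretely, for a hypothetical op Foo with one input x and one non-inferred attr mode, the fast-path block appended here reads roughly as follows (indentation approximate, since it sits under the wrapper's else: branch):

  try:
    _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
      _ctx._context_handle, _ctx._eager_context.device_name, "Foo", name,
      _ctx._post_execution_callbacks, x, "mode", mode)
    return _result
  except _core._FallbackException:
    return foo_eager_fallback(
        x, mode=mode, name=name, ctx=_ctx)
  except _core._NotOkStatusException as e:
    if name is not None:
      message = e.message + " name: " + name
    else:
      message = e.message
    _six.raise_from(_core._status_to_exception(e.code, message), None)
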
-void GenPythonOp::AddDocStringAttrs() {
- for (const string& name : attrs_) {
- const auto& attr = *FindAttr(name, op_def_);
- const auto& api_def_attr = *FindAttr(name, api_def_);
- string desc =
- strings::StrCat(AvoidPythonReserved(api_def_attr.rename_to()), ": ");
-
- static const char* const kAttrTypeName[][2] = {
- {"string", "`string`"},
- {"list(string)", "list of `strings`"},
- {"int", "`int`"},
- {"list(int)", "list of `ints`"},
- {"float", "`float`"},
- {"list(float)", "list of `floats`"},
- {"bool", "`bool`"},
- {"list(bool)", "list of `bools`"},
- {"type", "`tf.DType`"},
- {"list(type)", "list of `tf.DTypes`"},
- {"shape", "`tf.TensorShape` or list of `ints`"},
- {"list(shape)",
- "list of shapes (each a `tf.TensorShape` or list of `ints`)"},
- {"tensor", "`tf.TensorProto`"},
- {"list(tensor)", "list of `tf.TensorProto` objects"},
- {"func", "function decorated with @Defun"},
- {"list(func)", "list of functions decorated with @Defun"},
- };
- for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
- if (attr.type() == kAttrTypeName[i][0]) {
- string s;
- if (api_def_attr.has_default_value()) {
- s = strings::StrCat("optional ", kAttrTypeName[i][1]);
+void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
+ // Figure out values for inferred attrs, and cast to eager tensors.
+ for (int i = 0; i < op_def_.attr_size(); ++i) {
+ const auto& attr(op_def_.attr(i));
+ const auto& api_def_attr(api_def_.attr(i));
+ auto arg_list = attr_to_args_.find(attr.name());
+ if (arg_list != attr_to_args_.end()) {
+ if (attr.type() == "type") {
+ std::vector<string> output_sizes;
+ const string flattened =
+ FlattenInputs(&arg_list->second, &output_sizes);
+ string conversion = strings::StrCat("_execute.args_to_matching_eager(",
+ flattened, ", _ctx");
+ if (attr.has_default_value()) {
+ strings::StrAppend(
+ &conversion, ", ",
+ python_op_gen_internal::AttrValueToPython(
+ attr.type(), api_def_attr.default_value(), "_dtypes."));
+ }
+ strings::StrAppend(&conversion, ")");
+ const string var_name = AttrVarName(attr.name(), &attr_expressions_);
+ if (output_sizes.size() == 1) {
+ // Avoid creating a temporary variable in the case where
+ // we can easily assign to the right value directly.
+ const string inputs_var =
+ param_names_[arg_list->second.front()].GetRenameTo();
+ if (output_sizes.front().empty()) {
+ strings::StrAppend(&result_, indentation, var_name, ", (",
+ inputs_var, ",) = ", conversion, "\n");
+ } else {
+ strings::StrAppend(&result_, indentation, var_name, ", ",
+ inputs_var, " = ", conversion, "\n");
+ }
} else {
- s = kAttrTypeName[i][1];
+ const string inputs_var = strings::StrCat("_inputs_", attr.name());
+ strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
+ " = ", conversion, "\n");
+ // Convert from a flat list of eager tensors back to the
+ // parameter variables.
+ Unflatten(indentation, output_sizes, inputs_var, &result_);
+ std::vector<string> p;
+ for (int j : arg_list->second) {
+ p.emplace_back(param_names_[j].GetRenameTo());
+ }
+ strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
+ inputs_var, "\n");
}
- if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) {
- strings::StrAppend(&desc, "An ", s);
+ } else if (attr.type() == "list(type)") {
+ // NOTE: We ignore default values for these attrs, since it is
+ // unclear how you would use it, and the one use case is
+ // parse_single_sequence_example which only needs it for
+ // backwards compatibility.
+ const string var_name = AttrVarName(attr.name(), &attr_expressions_);
+ string inputs_var;
+ string conversion;
+ if (arg_list->second.size() > 1) {
+ // If you have more than one list(tensor) argument, their types
+ // have to match.
+ std::vector<string> lists;
+ for (auto iter = arg_list->second.begin();
+ iter != arg_list->second.end(); ++iter) {
+ lists.push_back(param_names_[*iter].GetRenameTo());
+ }
+ inputs_var = VectorToTuple(lists);
+ conversion = "_execute.args_to_mixed_eager_tensors";
} else {
- strings::StrAppend(&desc, "A ", s);
+ // For one list(tensor) argument, we just convert every
+ // element of the list to an eager tensor.
+ inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
+ conversion = "_execute.convert_to_mixed_eager_tensors";
}
- break;
+ strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
+ " = ", conversion, "(", inputs_var, ", _ctx)\n");
}
}
-
- if (attr.has_allowed_values()) {
- strings::StrAppend(&desc, " from: `",
- AttrListToPython(attr.allowed_values()), "`");
- }
-
- if (attr.has_minimum()) {
- if (attr.type() == "int") {
- strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`");
- } else if (attr.minimum() > 0) {
- strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`");
- }
- }
-
- strings::StrAppend(&desc, ".");
-
- if (api_def_attr.has_default_value()) {
- strings::StrAppend(
- &desc, " Defaults to `",
- AttrValueToPython(attr.type(), api_def_attr.default_value()), "`.");
- }
- if (!api_def_attr.description().empty()) {
- AppendWithinWidth(&desc, api_def_attr.description(),
- kRightMargin - 4 /* indent */);
- }
- strings::StrAppend(&result_, Indent(4, 6, desc));
}
}
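
For instance (hypothetical names throughout): a single "type" attr T inferred from one plain input x comes out as a one-liner, while a "list(type)" attr inferred from a single list input components uses the mixed-tensor helper:

  _attr_T, (x,) = _execute.args_to_matching_eager([x], _ctx)
  _attr_Tcomponents, components = _execute.convert_to_mixed_eager_tensors(components, _ctx)
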
-void GenPythonOp::AddDocStringNameArg() {
- strings::StrAppend(&result_,
- " name: A name for the operation (optional).\n");
+void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
+ // Cast remaining args to eager tensors
+ for (int i = 0; i < op_def_.input_arg_size(); ++i) {
+ const auto& arg(op_def_.input_arg(i));
+ if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
+ const string& param = param_names_[i].GetRenameTo();
+ const string fn = arg.number_attr().empty() ? "" : "n_";
+ const string dtype =
+ python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
+ strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
+ "to_tensor(", param, ", ", dtype, ")\n");
+ }
}
-void GenPythonOp::AddOutputGlobals() {
- // Prepare a NamedTuple type to hold the outputs, if there are multiple
- if (num_outs_ > 1) {
- // Prepare the list of output names
- std::vector<string> out_names(num_outs_);
- for (int i = 0; i < num_outs_; ++i) {
- if (!api_def_.out_arg(i).rename_to().empty()) {
- out_names[i] = api_def_.out_arg(i).rename_to();
- } else {
- out_names[i] = strings::StrCat("output", i);
- }
+void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
+ // Compute eager attrs
+ if (op_def_.attr_size() > 0) {
+ string attr_values;
+ for (int i = 0; i < op_def_.attr_size(); ++i) {
+ if (i > 0) strings::StrAppend(&attr_values, ", ");
+ const auto& attr_name(op_def_.attr(i).name());
+ strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
+ attr_expressions_[attr_name]);
}
- string out_names_list =
- strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]");
-
- // Provide the output names as a Python list
- string lower_op_name_outputs =
- strings::StrCat("_", function_name_, "_outputs");
- const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = ");
- strings::StrAppend(&prelude_, "\n",
- WordWrap(outputs_prefix, out_names_list, kRightMargin),
- "\n");
-
- strings::StrAppend(&prelude_, "_", op_def_.name(),
- "Output = _collections.namedtuple(\n");
- const string tuple_type_prefix = " ";
- const string tuple_type_suffix = strings::StrCat(
- "\"", op_def_.name(), "\", ", lower_op_name_outputs, ")");
+ strings::StrAppend(&attr_values, ")");
strings::StrAppend(
- &prelude_, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin),
- "\n\n");
- }
- strings::StrAppend(&prelude_, "\n");
-}
-
-void GenPythonOp::AddDocStringOutputs() {
- std::vector<string> output_type_string;
- output_type_string.reserve(num_outs_);
- for (int i = 0; i < num_outs_; ++i) {
- output_type_string.push_back(
- ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true));
- }
- strings::StrAppend(&result_, GetReturns(op_def_, output_type_string));
-}
-
-void GenPythonOp::AddBody(const string& prefix) {
- const string apply_prefix =
- strings::StrCat(prefix, "_result = _op_def_lib.apply_op(");
- AddBodyNoReturn(apply_prefix);
- if (num_outs_ > 1) {
- strings::StrAppend(&result_, prefix, "_result = _", op_def_.name(),
- "Output._make(_result)\n");
+ &result_,
+ WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
+ kRightMargin),
+ "\n");
+ } else {
+ strings::StrAppend(&result_, indentation, "_attrs = None\n");
}
- strings::StrAppend(&result_, prefix, "return _result\n");
}
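
The resulting line for a hypothetical op with attrs T and N is simply:

  _attrs = ("T", _attr_T, "N", _attr_N)
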
-void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) {
- string args = strings::StrCat("\"", op_def_.name(), "\", ");
- for (size_t i = 0; i < param_names_.size(); ++i) {
- strings::StrAppend(&args, AvoidPythonReserved(param_names_[i].GetName()),
- "=", param_names_[i].GetRenameTo(), ", ");
- }
- strings::StrAppend(&args, "name=name)");
-
+void GenEagerPythonOp::AddEagerExecute(const string& indentation,
+ const string& num_outputs_expr) {
+ const string return_prefix =
+ strings::StrCat(indentation, "_result = _execute.execute(");
+ const string return_args = strings::StrCat(
+ "b\"", op_def_.name(), "\", ", num_outputs_expr,
+ ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)");
strings::StrAppend(&result_,
// Wrap the arguments, and indent to the (.
- WordWrap(apply_prefix, args, kRightMargin), "\n");
-}
-
-} // namespace python_op_gen_internal
-
-string GetPythonOp(const OpDef& op_def, const ApiDef& api_def,
- const string& function_name) {
- return python_op_gen_internal::GenPythonOp(op_def, api_def, function_name)
- .Code();
+ WordWrap(return_prefix, return_args, kRightMargin), "\n");
}
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
- const std::vector<string>& hidden_ops,
- bool require_shapes) {
+ const std::vector<string>& hidden_ops, bool require_shapes,
+ const string& source_file_name = "") {
string result;
// Header
// TODO(josh11b): Mention the library for which wrappers are being generated.
strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.
This file is MACHINE GENERATED! Do not edit.
-"""
+)");
+
+ // Mention the original source file so someone tracing back through
+ // generated Python code will know where to look next.
+ if (!source_file_name.empty()) {
+ strings::StrAppend(&result, "Original C++ source file: ");
+ strings::StrAppend(&result, source_file_name);
+ strings::StrAppend(&result, "\n");
+ }
+
+ strings::StrAppend(&result, R"("""
import collections as _collections
+import six as _six
-from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
+from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
+from tensorflow.python.eager import context as _context
+from tensorflow.python.eager import core as _core
+from tensorflow.python.eager import execute as _execute
+from tensorflow.python.framework import dtypes as _dtypes
+from tensorflow.python.framework import errors as _errors
+from tensorflow.python.framework import tensor_shape as _tensor_shape
+from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes as _common_shapes
-
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.tf_export import tf_export
+
)");
// We'll make a copy of ops that filters out descriptions.
@@ -839,7 +957,6 @@ from tensorflow.python.util.tf_export import tf_export
if (api_def->visibility() == ApiDef::SKIP) {
continue;
}
-
// An op is hidden if either its ApiDef visibility is HIDDEN
// or it is in the hidden_ops list.
bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
@@ -875,11 +992,12 @@ from tensorflow.python.util.tf_export import tf_export
continue;
}
- strings::StrAppend(&result, GetPythonOp(op_def, *api_def, function_name));
+ strings::StrAppend(&result,
+ GetEagerPythonOp(op_def, *api_def, function_name));
if (!require_shapes) {
strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
- "\")(None)\n");
+ "\")(None)\n\n");
}
auto added = out->Add();
@@ -894,8 +1012,6 @@ from tensorflow.python.util.tf_export import tf_export
op_def_lib = _op_def_library.OpDefLibrary()
op_def_lib.add_op_list(op_list)
return op_def_lib
-
-
)");
result.append("# ");
@@ -908,16 +1024,21 @@ from tensorflow.python.util.tf_export import tf_export
return result;
}
+} // namespace
+
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
- const std::vector<string>& hidden_ops,
- bool require_shapes) {
- printf("%s", GetPythonOps(ops, api_defs, hidden_ops, require_shapes).c_str());
+ const std::vector<string>& hidden_ops, bool require_shapes,
+ const string& source_file_name) {
+ printf("%s", GetPythonOps(ops, api_defs, hidden_ops, require_shapes,
+ source_file_name)
+ .c_str());
}
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
string op_list_str(op_list_buf, op_list_len);
OpList ops;
ops.ParseFromString(op_list_str);
+
ApiDefMap api_def_map(ops);
return GetPythonOps(ops, api_def_map, {}, false);
}
diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h
index 4d20888dc6..7e754fd122 100644
--- a/tensorflow/python/framework/python_op_gen.h
+++ b/tensorflow/python/framework/python_op_gen.h
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,29 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_
#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_
#include <string>
#include <vector>
-#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-// hidden_ops should be a vector of Op names that should get a leading _ in the
-// output.
-// The Print* version prints the output to stdout, Get* version returns the
-// output as a string.
+// hidden_ops should be a list of Op names that should get a leading _
+// in the output. Prints the output to stdout.
+// Optional fourth argument is the name of the original C++ source file
+// where the ops' REGISTER_OP() calls reside.
void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
- const std::vector<string>& hidden_ops, bool require_shapes);
-string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
- const std::vector<string>& hidden_ops, bool require_shapes);
-string GetPythonOp(const OpDef& op_def, const ApiDef& api_def,
- const string& function_name);
+ const std::vector<string>& hidden_ops, bool require_shapes,
+ const string& source_file_name = "");
// Get the python wrappers for a list of ops in a OpList.
// `op_list_buf` should be a pointer to a buffer containing
diff --git a/tensorflow/python/framework/python_op_gen.i b/tensorflow/python/framework/python_op_gen.i
index efcce2f209..26ec4e8e66 100644
--- a/tensorflow/python/framework/python_op_gen.i
+++ b/tensorflow/python/framework/python_op_gen.i
@@ -16,10 +16,10 @@ limitations under the License.
%include "tensorflow/python/platform/base.i"
%{
-#include "tensorflow/python/eager/python_eager_op_gen.h"
+#include "tensorflow/python/framework/python_op_gen.h"
%}
-// Input typemap for GetEagerPythonWrappers.
+// Input typemap for GetPythonWrappers.
// Accepts a python object of 'bytes' type, and converts it to
// a const char* pointer and size_t length. The default typemap
// going from python bytes to const char* tries to decode the
@@ -37,5 +37,5 @@ limitations under the License.
%ignoreall;
-%unignore tensorflow::GetEagerPythonWrappers;
-%include "tensorflow/python/eager/python_eager_op_gen.h"
+%unignore tensorflow::GetPythonWrappers;
+%include "tensorflow/python/framework/python_op_gen.h"
diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc
new file mode 100644
index 0000000000..940bffb906
--- /dev/null
+++ b/tensorflow/python/framework/python_op_gen_internal.cc
@@ -0,0 +1,800 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/python/framework/python_op_gen_internal.h"
+
+#include <stdio.h>
+#include <sstream>
+#include <unordered_map>
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb_text.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/framework/tensor.pb_text.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace python_op_gen_internal {
+
+const int kRightMargin = 78;
+
+bool IsPythonReserved(const string& s) {
+ static const std::set<string>* const kPythonReserved = new std::set<string>(
+ {// Keywords in Python, from:
+ // import keyword
+ // print keyword.kwlist
+ "and", "as", "assert", "break", "class", "continue", "def", "del",
+ "elif", "else", "except", "exec", "finally", "for", "from", "global",
+ "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",
+ "raise", "return", "try", "while", "with", "yield",
+ // Built-in functions and types in Python, from:
+ // [x for x in dir(__builtins__) if not x[0].islower()]
+ "ArithmeticError", "AssertionError", "AttributeError", "BaseException",
+ "BufferError", "BytesWarning", "DeprecationWarning", "EOFError",
+ "Ellipsis", "EnvironmentError", "Exception", "False",
+ "FloatingPointError", "FutureWarning", "GeneratorExit", "IOError",
+ "ImportError", "ImportWarning", "IndentationError", "IndexError",
+ "KeyError", "KeyboardInterrupt", "LookupError", "MemoryError",
+ "NameError", "None", "NotImplemented", "NotImplementedError", "OSError",
+ "OverflowError", "PendingDeprecationWarning", "ReferenceError",
+ "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration",
+ "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit", "TabError",
+ "True", "TypeError", "UnboundLocalError", "UnicodeDecodeError",
+ "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
+ "UnicodeWarning", "UserWarning", "ValueError", "Warning",
+ "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__",
+ "__package__"});
+
+ return kPythonReserved->count(s) > 0;
+}
+
+bool IsOpWithUnderscorePrefix(const string& s) {
+ static const std::set<string>* const kUnderscoreOps = new std::set<string>(
+ {// Lowercase built-in functions and types in Python, from:
+ // [x for x in dir(__builtins__) if x[0].islower()] except "round".
+ // These need to be excluded so they don't conflict with actual built-in
+ // functions since we use '*' imports.
+ "abs", "all", "any", "apply", "bin", "bool", "buffer", "bytearray",
+ "bytes", "callable", "chr", "classmethod", "cmp", "coerce", "compile",
+ "complex", "copyright", "credits", "delattr", "dict", "dir", "divmod",
+ "enumerate", "eval", "execfile", "exit", "file", "filter", "float",
+ "format", "frozenset", "getattr", "globals", "hasattr", "hash", "help",
+ "hex", "id", "input", "int", "intern", "isinstance", "issubclass",
+ "iter", "len", "license", "list", "locals", "long", "map", "max",
+ "memoryview", "min", "next", "object", "oct", "open", "ord", "pow",
+ "print", "property", "quit", "range", "raw_input", "reduce", "reload",
+ "repr", "reversed", "set", "setattr", "slice", "sorted", "staticmethod",
+ "str", "sum", "super", "tuple", "type", "unichr", "unicode", "vars",
+ "xrange", "zip",
+ // These have the same name as ops defined in Python and might be used
+ // incorrectly depending on order of '*' imports.
+ // TODO(annarev): reduce usage of '*' imports and remove these from the
+ // list.
+ "fused_batch_norm", "histogram_fixed_width", "stack",
+ "batch_norm_with_global_normalization", "clip_by_value"});
+ return kUnderscoreOps->count(s) > 0;
+}
+
+string AvoidPythonReserved(const string& s) {
+ if (IsPythonReserved(s)) return strings::StrCat(s, "_");
+ return s;
+}
+
+// Indent the first line by "initial" spaces and all following lines
+// by "rest" spaces.
+string Indent(int initial, int rest, StringPiece in) {
+ // TODO(josh11b): Also word-wrapping?
+ string copy(in.data(), in.size());
+ str_util::StripTrailingWhitespace(&copy);
+ std::vector<string> v = str_util::Split(copy, '\n');
+
+ string result;
+ bool first = true;
+ for (const string& line : v) {
+ if (first) {
+ result = strings::StrCat(Spaces(initial), line, "\n");
+ first = false;
+ } else {
+ if (line.empty()) {
+ strings::StrAppend(&result, "\n");
+ } else {
+ strings::StrAppend(&result, Spaces(rest), line, "\n");
+ }
+ }
+ }
+ return result;
+}
+
+// Adds append to *dest, with a space if the first line will be <= width,
+// or a newline otherwise.
+void AppendWithinWidth(string* dest, StringPiece append, int width) {
+ auto first_line = append.find('\n');
+ if (first_line == string::npos) first_line = append.size();
+ if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) {
+ strings::StrAppend(dest, "\n", append);
+ } else {
+ strings::StrAppend(dest, " ", append);
+ }
+}
+
+// Like DataTypeString() but uses the Python names for the
+// float types.
+string PythonDataTypeString(DataType dtype) {
+ switch (dtype) {
+ case DT_FLOAT:
+ return "float32";
+ case DT_DOUBLE:
+ return "float64";
+ default:
+ return DataTypeString(dtype);
+ }
+}
+
+string TypeString(DataType dtype, bool ref) {
+ if (ref) {
+ return strings::StrCat("mutable `", PythonDataTypeString(dtype), "`");
+ } else {
+ return strings::StrCat("`", PythonDataTypeString(dtype), "`");
+ }
+}
+
+string TypeListString(const AttrValue& value) {
+ string ret;
+ for (int t : value.list().type()) {
+ if (!ret.empty()) strings::StrAppend(&ret, ", ");
+ DataType dtype = static_cast<DataType>(t);
+ if (IsRefType(dtype)) {
+ strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)),
+ " mutable");
+ } else {
+ strings::StrAppend(&ret, "`", PythonDataTypeString(dtype), "`");
+ }
+ }
+ return ret;
+}
+
+string SingleTensorName(DataType dtype, bool is_ref) {
+ const string type_str = TypeString(dtype, is_ref);
+ return strings::StrCat("A `Tensor` of type ", type_str, ".");
+}
+
+const char kUnknownTensorType[] = {"A `Tensor`."};
+
+string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
+ const std::unordered_map<string, string>& inferred_attrs,
+ bool is_output) {
+ if (!arg.number_attr().empty()) {
+ // N Tensors with the same type
+ const string* original_arg =
+ gtl::FindOrNull(inferred_attrs, arg.number_attr());
+ string prefix;
+ if (original_arg == nullptr) {
+ prefix = strings::StrCat("A list of `", arg.number_attr(), "`");
+ } else if (*original_arg == arg.name()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
+ if (attr->has_minimum() && attr->minimum() > 0) {
+ prefix = strings::StrCat("A list of at least ", attr->minimum());
+ } else {
+ prefix = "A list of";
+ }
+ } else {
+ prefix = strings::StrCat("A list with the same length as `",
+ AvoidPythonReserved(*original_arg), "` of");
+ }
+
+ if (arg.type() != DT_INVALID) {
+ return strings::StrCat(prefix, " `Tensor` objects with type ",
+ TypeString(arg.type(), arg.is_ref()), ".");
+ } else {
+ original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr());
+ if (arg.is_ref()) {
+ strings::StrAppend(&prefix, " mutable");
+ }
+ if (original_arg == nullptr) {
+ return strings::StrCat(prefix, " `Tensor` objects with type `",
+ arg.type_attr(), "`.");
+ } else if (*original_arg == arg.name()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
+ if (attr->has_allowed_values()) {
+ return strings::StrCat(prefix,
+ " `Tensor` objects with the same type in: ",
+ TypeListString(attr->allowed_values()), ".");
+ } else {
+ return strings::StrCat(prefix,
+ " `Tensor` objects with the same type.");
+ }
+ } else {
+ return strings::StrCat(prefix,
+ " `Tensor` objects with the same type as `",
+ AvoidPythonReserved(*original_arg), "`.");
+ }
+ }
+ } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) {
+ const bool is_list = !arg.type_list_attr().empty();
+ const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr();
+ const OpDef::AttrDef* attr = FindAttr(attr_name, op_def);
+ const string mutable_str = arg.is_ref() ? "mutable " : "";
+ const string prefix =
+ is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects")
+ : strings::StrCat("A ", mutable_str, "`Tensor`");
+ const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name);
+ if (original_arg == nullptr) {
+ return strings::StrCat(prefix, " of type `", attr_name, "`.");
+ } else if (*original_arg == arg.name()) {
+ if (attr->has_allowed_values()) {
+ if (is_list) {
+ return strings::StrCat(prefix, " with types from: ",
+ TypeListString(attr->allowed_values()), ".");
+ } else {
+ return strings::StrCat(
+ prefix, is_output ? ". Has one of the following types: "
+ : ". Must be one of the following types: ",
+ TypeListString(attr->allowed_values()), ".");
+ }
+ } else {
+ return strings::StrCat(prefix, ".");
+ }
+ } else {
+ return strings::StrCat(prefix,
+ is_output ? ". Has the same type as `"
+ : ". Must have the same type as `",
+ AvoidPythonReserved(*original_arg), "`.");
+ }
+ } else {
+ return SingleTensorName(arg.type(), arg.is_ref());
+ }
+}
+
+string GetReturns(const OpDef& op_def,
+ const std::vector<string>& output_type_string) {
+ string result;
+ DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
+ const int num_outs = op_def.output_arg_size();
+ strings::StrAppend(&result, "\n Returns:\n");
+ if (num_outs == 0) {
+ strings::StrAppend(&result, " The created Operation.\n");
+ } else {
+ if (num_outs == 1) {
+ StringPiece description = op_def.output_arg(0).description();
+ if (ConsumeEquals(&description)) { // Skip the generated type info.
+ strings::StrAppend(&result, Indent(4, 4, description));
+ } else {
+ // Special case of one output, don't use the name of the output unless
+ // there is no description.
+ string desc = output_type_string.empty() ? kUnknownTensorType
+ : output_type_string[0];
+ if (desc == kUnknownTensorType) {
+ // Special case where we don't understand how the output tensor type
+ // depends on the input tensor types, just use the output arg
+ // description if we can.
+ if (!description.empty()) {
+ desc = op_def.output_arg(0).description();
+ } else if (!op_def.output_arg(0).name().empty()) {
+ desc = strings::StrCat(" The ", op_def.output_arg(0).name(),
+ " `Tensor`.");
+ }
+ } else if (!description.empty()) {
+ AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
+ }
+ strings::StrAppend(&result, Indent(4, 4, desc));
+ }
+ } else {
+ std::vector<string> out_names(num_outs);
+ for (int i = 0; i < num_outs; ++i) {
+ if (!op_def.output_arg(i).name().empty()) {
+ out_names[i] = op_def.output_arg(i).name();
+ } else {
+ out_names[i] = strings::StrCat("output", i);
+ }
+ }
+ strings::StrAppend(&result, " A tuple of `Tensor` objects (",
+ str_util::Join(out_names, ", "), ").\n\n");
+ for (int i = 0; i < num_outs; ++i) {
+ string desc = strings::StrCat(out_names[i], ": ");
+ StringPiece description = op_def.output_arg(i).description();
+ if (ConsumeEquals(&description)) { // Skip the generated type info.
+ strings::StrAppend(&desc, description);
+ } else {
+ const string type = static_cast<size_t>(i) < output_type_string.size()
+ ? output_type_string[i]
+ : kUnknownTensorType;
+ if (!description.empty()) {
+ if (type == kUnknownTensorType) {
+ // Special case where we don't understand how the output tensor
+ // type depends on the input tensor types, so we just use the
+ // output arg description.
+ strings::StrAppend(&desc, description);
+ } else {
+ strings::StrAppend(&desc, type, " ", description);
+ }
+ } else {
+ strings::StrAppend(&desc, type);
+ }
+ }
+ strings::StrAppend(&result, Indent(4, 6, desc));
+ }
+ }
+ }
+ return result;
+}
+
+string StringToPython(const string& str) {
+ return strings::StrCat("\"", str_util::CEscape(str), "\"");
+}
+
+string DataTypeToPython(DataType dtype, const string& dtype_module) {
+ return strings::StrCat(dtype_module, PythonDataTypeString(dtype));
+}
+
+string ShapeToPython(const TensorShapeProto& shape) {
+ if (shape.unknown_rank()) {
+ return "None";
+ }
+ string python = "[";
+ for (const auto& dim : shape.dim()) {
+ if (python.size() > 1) strings::StrAppend(&python, ", ");
+ if (!dim.name().empty()) {
+ strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ",
+ dim.size(), ")");
+ } else {
+ strings::StrAppend(&python, dim.size());
+ }
+ }
+ strings::StrAppend(&python, "]");
+ return python;
+}
+
+string TensorToPython(const TensorProto& proto) {
+ return ProtoShortDebugString(proto);
+}
+
+string AttrListToPython(const AttrValue& value,
+ const string& dtype_module = "tf.") {
+ string ret;
+ if (value.list().s_size() > 0) {
+ for (int i = 0; i < value.list().s_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, StringToPython(value.list().s(i)));
+ }
+ } else if (value.list().i_size() > 0) {
+ for (int i = 0; i < value.list().i_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, value.list().i(i));
+ }
+ } else if (value.list().f_size() > 0) {
+ for (int i = 0; i < value.list().f_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, value.list().f(i));
+ }
+ } else if (value.list().b_size() > 0) {
+ for (int i = 0; i < value.list().b_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, value.list().b(i) ? "True" : "False");
+ }
+ } else if (value.list().type_size() > 0) {
+ for (int i = 0; i < value.list().type_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret,
+ DataTypeToPython(value.list().type(i), dtype_module));
+ }
+ } else if (value.list().shape_size() > 0) {
+ for (int i = 0; i < value.list().shape_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, ShapeToPython(value.list().shape(i)));
+ }
+ } else if (value.list().tensor_size() > 0) {
+ for (int i = 0; i < value.list().tensor_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, TensorToPython(value.list().tensor(i)));
+ }
+ } else if (value.list().func_size() > 0) {
+ for (int i = 0; i < value.list().func_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, StringToPython(value.list().func(i).name()));
+ }
+ }
+ return ret;
+}
+
+// NOTE: The return value may contain spaces (for example, it could be
+// a string "foo bar" with an embedded space) and is not safe to pass
+// to WordWrap().
+string AttrValueToPython(const string& type, const AttrValue& value,
+ const string& dtype_module) {
+ if (type == "string") {
+ return StringToPython(value.s());
+ } else if (type == "int") {
+ return strings::StrCat(value.i());
+ } else if (type == "float") {
+ if (std::isnan(value.f()) || std::isinf(value.f())) {
+ return strings::StrCat("float('", value.f(), "')");
+ } else {
+ return strings::StrCat(value.f());
+ }
+ } else if (type == "bool") {
+ return value.b() ? "True" : "False";
+ } else if (type == "type") {
+ return DataTypeToPython(value.type(), dtype_module);
+ } else if (type == "shape") {
+ return ShapeToPython(value.shape());
+ } else if (type == "tensor") {
+ return TensorToPython(value.tensor());
+ } else if (type == "func") {
+ return StringToPython(value.func().name());
+ } else if (str_util::StartsWith(type, "list(")) {
+ return strings::StrCat("[", AttrListToPython(value, dtype_module), "]");
+ } else {
+ return "?";
+ }
+}
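
For reference, the scalar rendering rules implemented above can be restated as a
small, self-contained Python sketch (illustrative only, not TensorFlow API):

import math

def attr_value_to_python(kind, value, dtype_module="tf."):
    # Illustrative restatement of the scalar cases handled by AttrValueToPython.
    if kind == "string":
        return '"%s"' % value
    if kind == "int":
        return str(value)
    if kind == "float":
        # NaN/inf need the float('...') spelling to remain valid Python.
        if math.isnan(value) or math.isinf(value):
            return "float('%s')" % value
        return repr(value)
    if kind == "bool":
        return "True" if value else "False"
    if kind == "type":
        return dtype_module + value              # e.g. "tf.float32"
    if kind == "shape":
        return "None" if value is None else str(list(value))
    return "?"

print(attr_value_to_python("type", "int32"))         # tf.int32
print(attr_value_to_python("shape", (128, 128, 3)))  # [128, 128, 3]
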
+
+void GenerateLowerCaseOpName(const string& str, string* result) {
+ const char joiner = '_';
+ const int last_index = str.size() - 1;
+ for (int i = 0; i <= last_index; ++i) {
+ const char c = str[i];
+ // Emit a joiner only if a previous-lower-to-now-upper or a
+ // now-upper-to-next-lower transition happens.
+ if (isupper(c) && (i > 0)) {
+ if (islower(str[i - 1]) || ((i < last_index) && islower(str[i + 1]))) {
+ result->push_back(joiner);
+ }
+ }
+ result->push_back(tolower(c));
+ }
+}
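
A rough Python equivalent of the CamelCase-to-snake_case rule above, shown only
to illustrate where the joiner is emitted (not part of this change):

def lower_op_name(name):
    out = []
    for i, c in enumerate(name):
        if c.isupper() and i > 0:
            prev_lower = name[i - 1].islower()
            next_lower = i + 1 < len(name) and name[i + 1].islower()
            if prev_lower or next_lower:   # lower->UPPER or UPPER->lower transition
                out.append("_")
        out.append(c.lower())
    return "".join(out)

print(lower_op_name("MatMul"))         # mat_mul
print(lower_op_name("BatchMatMulV2"))  # batch_mat_mul_v2
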
+
+static void AddDelimiter(string* append_to, const string& delim) {
+ if (!append_to->empty()) strings::StrAppend(append_to, delim);
+}
+
+const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
+ for (int i = 0; i < api_def.attr_size(); ++i) {
+ if (api_def.attr(i).name() == name) {
+ return &api_def.attr(i);
+ }
+ }
+ return nullptr;
+}
+
+const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
+ for (int i = 0; i < api_def.in_arg_size(); ++i) {
+ if (api_def.in_arg(i).name() == name) {
+ return &api_def.in_arg(i);
+ }
+ }
+ return nullptr;
+}
+
+GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
+ const string& function_name)
+ : op_def_(op_def),
+ api_def_(api_def),
+ function_name_(function_name),
+ num_outs_(op_def.output_arg_size()) {}
+
+GenPythonOp::~GenPythonOp() {}
+
+string GenPythonOp::Code() {
+ // This has all the input args followed by those attrs that don't have
+ // defaults.
+ std::vector<ParamNames> params_no_default;
+ // The parameters with defaults (these have to be listed after those without).
+ // No input args are included, just attrs.
+ std::vector<ParamNames> params_with_default;
+
+ for (int i = 0; i < api_def_.arg_order_size(); ++i) {
+ const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
+ const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
+ params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
+ if (!arg.type_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name());
+ } else if (!arg.type_list_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(),
+ arg.name());
+ }
+ if (!arg.number_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name());
+ }
+ }
+ for (int i = 0; i < api_def_.attr_size(); ++i) {
+ const auto& attr(api_def_.attr(i));
+ // Do not add inferred attrs to the Python function signature.
+ if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
+ if (attr.has_default_value()) {
+ params_with_default.emplace_back(attr.name(), attr.rename_to());
+ } else {
+ params_no_default.emplace_back(attr.name(), attr.rename_to());
+ }
+ }
+ }
+
+ // Save the list of attr parameters (attrs that won't be inferred),
+ // those with defaults go at the end.
+ // Get the attrs in the order we want by taking the attrs without defaults
+  // from the end of params_no_default, and then adding params_with_default.
+ attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
+ params_with_default.size());
+ for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) {
+ attrs_.push_back(params_no_default[i].GetName());
+ }
+ for (int i = 0; i < params_with_default.size(); ++i) {
+ attrs_.push_back(params_with_default[i].GetName());
+ }
+
+ param_names_.reserve(params_no_default.size() + params_with_default.size());
+ param_names_.insert(param_names_.begin(), params_no_default.begin(),
+ params_no_default.end());
+ for (const auto& param : params_with_default) {
+ param_names_.push_back(param);
+ }
+
+ string parameters;
+ for (const auto& param : params_no_default) {
+ AddDelimiter(&parameters, ", ");
+ strings::StrAppend(&parameters, param.GetRenameTo());
+ }
+ for (const auto& param_and_default : params_with_default) {
+ AddDelimiter(&parameters, ", ");
+ strings::StrAppend(&parameters, param_and_default.GetRenameTo(), "=None");
+ }
+ AddDelimiter(&parameters, ", ");
+ strings::StrAppend(&parameters, "name=None");
+
+ AddExport();
+ AddDefLine(parameters);
+ AddDocStringDescription();
+ AddDocStringArgs();
+ AddDocStringInputs();
+ AddDocStringAttrs();
+ AddDocStringNameArg();
+ AddOutputGlobals();
+ AddDocStringOutputs();
+ strings::StrAppend(&result_, " \"\"\"\n");
+ AddBody(" ");
+ strings::StrAppend(&result_, "\n\n");
+
+ return prelude_ + result_;
+}
+
+void GenPythonOp::AddExport() {
+ if (api_def_.visibility() != ApiDef::VISIBLE) {
+ return;
+ }
+
+ strings::StrAppend(&result_, "@tf_export(");
+
+ // Add all endpoint names to tf_export.
+ bool first_endpoint = true;
+ for (const auto& endpoint : api_def_.endpoint()) {
+ if (!first_endpoint) {
+ strings::StrAppend(&result_, ", ");
+ } else {
+ first_endpoint = false;
+ }
+ string endpoint_name;
+ python_op_gen_internal::GenerateLowerCaseOpName(endpoint.name(),
+ &endpoint_name);
+ strings::StrAppend(&result_, "'", endpoint_name, "'");
+ }
+ strings::StrAppend(&result_, ")\n");
+}
+
+void GenPythonOp::AddDefLine(const string& function_name,
+ const string& parameters) {
+ strings::StrAppend(&result_, "def ", function_name, "(", parameters, "):\n");
+}
+
+void GenPythonOp::AddDefLine(const string& parameters) {
+ AddDefLine(function_name_, parameters);
+}
+
+void GenPythonOp::AddDocStringDescription() {
+ string comment;
+ if (api_def_.summary().empty()) {
+ comment = "TODO: add doc.\n";
+ } else {
+ comment = strings::StrCat(api_def_.summary(), "\n");
+ if (!api_def_.description().empty()) {
+ strings::StrAppend(&comment, "\n", Indent(2, 2, api_def_.description()));
+ }
+ }
+ strings::StrAppend(&result_, " r\"\"\"", comment, "\n");
+}
+
+void GenPythonOp::AddDocStringArgs() {
+ strings::StrAppend(&result_, " Args:\n");
+}
+
+void GenPythonOp::AddDocStringInputs() {
+ for (int i = 0; i < api_def_.arg_order_size(); ++i) {
+ const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
+ const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
+ StringPiece description = api_def_arg.description();
+ string desc;
+ if (ConsumeEquals(&description)) { // Skip the generated type info.
+ desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ");
+ } else {
+ desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ",
+ ArgTypeName(op_def_, arg, inferred_attrs_, false));
+ }
+ if (!description.empty()) {
+ AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
+ }
+ strings::StrAppend(&result_, Indent(4, 6, desc));
+ }
+}
+
+void GenPythonOp::AddDocStringAttrs() {
+ for (const string& name : attrs_) {
+ const auto& attr = *FindAttr(name, op_def_);
+ const auto& api_def_attr = *FindAttr(name, api_def_);
+ string desc =
+ strings::StrCat(AvoidPythonReserved(api_def_attr.rename_to()), ": ");
+
+ static const char* const kAttrTypeName[][2] = {
+ {"string", "`string`"},
+ {"list(string)", "list of `strings`"},
+ {"int", "`int`"},
+ {"list(int)", "list of `ints`"},
+ {"float", "`float`"},
+ {"list(float)", "list of `floats`"},
+ {"bool", "`bool`"},
+ {"list(bool)", "list of `bools`"},
+ {"type", "`tf.DType`"},
+ {"list(type)", "list of `tf.DTypes`"},
+ {"shape", "`tf.TensorShape` or list of `ints`"},
+ {"list(shape)",
+ "list of shapes (each a `tf.TensorShape` or list of `ints`)"},
+ {"tensor", "`tf.TensorProto`"},
+ {"list(tensor)", "list of `tf.TensorProto` objects"},
+ {"func", "function decorated with @Defun"},
+ {"list(func)", "list of functions decorated with @Defun"},
+ };
+ for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
+ if (attr.type() == kAttrTypeName[i][0]) {
+ string s;
+ if (api_def_attr.has_default_value()) {
+ s = strings::StrCat("optional ", kAttrTypeName[i][1]);
+ } else {
+ s = kAttrTypeName[i][1];
+ }
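+      // Choose "An" vs. "A" depending on whether the description starts with a
+      // vowel sound (e.g. "optional ...", "`int`").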
+ if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) {
+ strings::StrAppend(&desc, "An ", s);
+ } else {
+ strings::StrAppend(&desc, "A ", s);
+ }
+ break;
+ }
+ }
+
+ if (attr.has_allowed_values()) {
+ strings::StrAppend(&desc, " from: `",
+ AttrListToPython(attr.allowed_values()), "`");
+ }
+
+ if (attr.has_minimum()) {
+ if (attr.type() == "int") {
+ strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`");
+ } else if (attr.minimum() > 0) {
+ strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`");
+ }
+ }
+
+ strings::StrAppend(&desc, ".");
+
+ if (api_def_attr.has_default_value()) {
+ strings::StrAppend(
+ &desc, " Defaults to `",
+ AttrValueToPython(attr.type(), api_def_attr.default_value()), "`.");
+ }
+ if (!api_def_attr.description().empty()) {
+ AppendWithinWidth(&desc, api_def_attr.description(),
+ kRightMargin - 4 /* indent */);
+ }
+ strings::StrAppend(&result_, Indent(4, 6, desc));
+ }
+}
+
+void GenPythonOp::AddDocStringNameArg() {
+ strings::StrAppend(&result_,
+ " name: A name for the operation (optional).\n");
+}
+
+void GenPythonOp::AddOutputGlobals() {
+ // Prepare a NamedTuple type to hold the outputs, if there are multiple
+ if (num_outs_ > 1) {
+ // Prepare the list of output names
+ std::vector<string> out_names(num_outs_);
+ for (int i = 0; i < num_outs_; ++i) {
+ if (!api_def_.out_arg(i).rename_to().empty()) {
+ out_names[i] = api_def_.out_arg(i).rename_to();
+ } else {
+ out_names[i] = strings::StrCat("output", i);
+ }
+ }
+ string out_names_list =
+ strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]");
+
+ // Provide the output names as a Python list
+ string lower_op_name_outputs =
+ strings::StrCat("_", function_name_, "_outputs");
+ const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = ");
+ strings::StrAppend(&prelude_, "\n",
+ WordWrap(outputs_prefix, out_names_list, kRightMargin),
+ "\n");
+
+ strings::StrAppend(&prelude_, "_", op_def_.name(),
+ "Output = _collections.namedtuple(\n");
+ const string tuple_type_prefix = " ";
+ const string tuple_type_suffix = strings::StrCat(
+ "\"", op_def_.name(), "\", ", lower_op_name_outputs, ")");
+ strings::StrAppend(
+ &prelude_, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin),
+ "\n\n");
+ }
+ strings::StrAppend(&prelude_, "\n");
+}
+
+void GenPythonOp::AddDocStringOutputs() {
+ std::vector<string> output_type_string;
+ output_type_string.reserve(num_outs_);
+ for (int i = 0; i < num_outs_; ++i) {
+ output_type_string.push_back(
+ ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true));
+ }
+ strings::StrAppend(&result_, GetReturns(op_def_, output_type_string));
+}
+
+void GenPythonOp::AddBody(const string& prefix) {
+ const string apply_prefix =
+ strings::StrCat(prefix, "_result = _op_def_lib.apply_op(");
+ AddBodyNoReturn(apply_prefix);
+ if (num_outs_ > 1) {
+ strings::StrAppend(&result_, prefix, "_result = _", op_def_.name(),
+ "Output._make(_result)\n");
+ }
+ strings::StrAppend(&result_, prefix, "return _result\n");
+}
+
+void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) {
+ string args = strings::StrCat("\"", op_def_.name(), "\", ");
+ for (size_t i = 0; i < param_names_.size(); ++i) {
+ strings::StrAppend(&args, AvoidPythonReserved(param_names_[i].GetName()),
+ "=", param_names_[i].GetRenameTo(), ", ");
+ }
+ strings::StrAppend(&args, "name=name)");
+
+ strings::StrAppend(&result_,
+ // Wrap the arguments, and indent to the (.
+ WordWrap(apply_prefix, args, kRightMargin), "\n");
+}
+
+} // namespace python_op_gen_internal
+} // namespace tensorflow
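
To make the output of the generator concrete, below is a hedged sketch of the
kind of Python wrapper GenPythonOp emits for a hypothetical two-output op
"FooBar"; the op, its "depth" attr, and the tf_export/_op_def_lib stubs are
invented here only so the snippet stands alone:

import collections as _collections

def tf_export(*names):                    # stub; the real decorator registers endpoints
  def deco(f):
    return f
  return deco

class _StubOpDefLib(object):              # stub; the generated module builds the real one
  def apply_op(self, op_name, **kwargs):
    return ("alpha_value", "beta_value")

_op_def_lib = _StubOpDefLib()

_foo_bar_outputs = ["alpha", "beta"]
_FooBarOutput = _collections.namedtuple("FooBar", _foo_bar_outputs)

@tf_export('foo_bar')
def foo_bar(input, depth, name=None):
  r"""TODO: add doc."""
  _result = _op_def_lib.apply_op("FooBar", input=input, depth=depth, name=name)
  _result = _FooBarOutput._make(_result)
  return _result

print(foo_bar(input=[1, 2], depth=3))     # FooBar(alpha='alpha_value', beta='beta_value')
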
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc
index ca6ed42bee..8eb943b960 100644
--- a/tensorflow/python/framework/python_op_gen_main.cc
+++ b/tensorflow/python/framework/python_op_gen_main.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/python/eager/python_eager_op_gen.h"
+#include "tensorflow/python/framework/python_op_gen.h"
#include <memory>
#include <string>
@@ -133,11 +133,10 @@ void PrintAllPythonOps(const std::vector<string>& op_list,
*pruned_ops.mutable_op()->Add() = op_def;
}
}
- PrintEagerPythonOps(pruned_ops, api_def_map, {}, require_shapes,
- source_file_name);
+ PrintPythonOps(pruned_ops, api_def_map, {}, require_shapes,
+ source_file_name);
} else {
- PrintEagerPythonOps(ops, api_def_map, op_list, require_shapes,
- source_file_name);
+ PrintPythonOps(ops, api_def_map, op_list, require_shapes, source_file_name);
}
}
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 8cf24206ed..ca63efbc84 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -50,6 +50,13 @@ def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
[ExtractBitsFromFloat16(x) for x in proto_values])
+def _MediumAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
+  # TODO: Remove the conversion if Cython supports np.float16_t
+ fast_tensor_util.AppendFloat16ArrayToTensorProto(
+ tensor_proto,
+ np.asarray(proto_values, dtype=np.float16).view(np.uint16))
+
+
def ExtractBitsFromBFloat16(x):
return np.asscalar(
np.asarray(x, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16))
@@ -64,11 +71,8 @@ if _FAST_TENSOR_UTIL_AVAILABLE:
_NP_TO_APPEND_FN = {
dtypes.bfloat16.as_numpy_dtype:
SlowAppendBFloat16ArrayToTensorProto,
- # TODO(sesse): We should have a
- # fast_tensor_util.AppendFloat16ArrayToTensorProto,
- # but it seems np.float16_t doesn't exist?
np.float16:
- SlowAppendFloat16ArrayToTensorProto,
+ _MediumAppendFloat16ArrayToTensorProto,
np.float32:
fast_tensor_util.AppendFloat32ArrayToTensorProto,
np.float64:
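
The new medium-speed float16 path above works by reinterpreting the half-precision
values as their raw uint16 bit patterns before handing them to the fast appender.
A minimal standalone illustration of that view trick:

import numpy as np

values = [1.0, 0.5, -2.0]
half_bits = np.asarray(values, dtype=np.float16).view(np.uint16)
print(half_bits)                   # [15360 14336 49152] -- raw float16 bit patterns
print(half_bits.view(np.float16))  # [ 1.   0.5 -2. ]    -- round-trips losslessly
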
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 523eb67935..b4213f0836 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -40,7 +40,6 @@ py_library(
"_impl/keras/datasets/imdb.py",
"_impl/keras/datasets/mnist.py",
"_impl/keras/datasets/reuters.py",
- "_impl/keras/estimator.py",
"_impl/keras/preprocessing/__init__.py",
"_impl/keras/preprocessing/image.py",
"_impl/keras/preprocessing/sequence.py",
@@ -74,7 +73,6 @@ py_library(
"datasets/imdb/__init__.py",
"datasets/mnist/__init__.py",
"datasets/reuters/__init__.py",
- "estimator/__init__.py",
"initializers/__init__.py",
"layers/__init__.py",
"losses/__init__.py",
@@ -99,8 +97,6 @@ py_library(
":backend",
":engine",
":layers",
- "//tensorflow/python/estimator",
- "//tensorflow/python/estimator:model_fn",
"//tensorflow/python/saved_model",
"//tensorflow/python:training",
],
@@ -316,7 +312,7 @@ py_test(
py_test(
name = "metrics_test",
- size = "small",
+ size = "medium",
srcs = ["_impl/keras/metrics_test.py"],
srcs_version = "PY2AND3",
tags = [
@@ -644,6 +640,7 @@ py_test(
name = "wrappers_test",
size = "medium",
srcs = ["_impl/keras/layers/wrappers_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = [
"noasan", # http://b/78599823
@@ -896,25 +893,6 @@ py_test(
)
py_test(
- name = "estimator_test",
- size = "large",
- srcs = ["_impl/keras/estimator_test.py"],
- srcs_version = "PY2AND3",
- tags = ["notsan"],
- deps = [
- ":keras",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:platform",
- "//tensorflow/python/estimator:numpy_io",
- "//tensorflow/python/estimator:run_config",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
name = "backend_test",
size = "small",
srcs = ["_impl/keras/backend_test.py"],
diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py
index f56be967ff..f76cfa6608 100644
--- a/tensorflow/python/keras/__init__.py
+++ b/tensorflow/python/keras/__init__.py
@@ -29,7 +29,6 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import constraints
from tensorflow.python.keras import datasets
-from tensorflow.python.keras import estimator
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py
index 53f5d31e9c..9bb140bfb8 100644
--- a/tensorflow/python/keras/_impl/keras/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/__init__.py
@@ -25,7 +25,6 @@ from tensorflow.python.keras._impl.keras import callbacks
from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import datasets
from tensorflow.python.keras._impl.keras import engine
-from tensorflow.python.keras._impl.keras import estimator
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import layers
from tensorflow.python.keras._impl.keras import losses
@@ -40,4 +39,4 @@ from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.models import Sequential
-__version__ = '2.1.5-tf'
+__version__ = '2.1.6-tf'
diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
index 7b7288793d..18a0612e13 100644
--- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
@@ -240,10 +240,15 @@ def MobileNet(input_shape=None,
'`0.25`, `0.50`, `0.75` or `1.0` only.')
if rows != cols or rows not in [128, 160, 192, 224]:
- raise ValueError('If imagenet weights are being loaded, '
- 'input must have a static square shape (one of '
- '(128,128), (160,160), (192,192), or (224, 224)).'
- ' Input shape provided = %s' % (input_shape,))
+ if rows is None:
+ rows = 224
+ logging.warning('MobileNet shape is undefined.'
+ ' Weights for input shape (224, 224) will be loaded.')
+ else:
+ raise ValueError('If imagenet weights are being loaded, '
+ 'input must have a static square shape (one of '
+ '(128, 128), (160, 160), (192, 192), or (224, 224)).'
+ ' Input shape provided = %s' % (input_shape,))
if K.image_data_format() != 'channels_last':
logging.warning('The MobileNet family of models is only available '
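
With the branch added above, a fully undefined spatial shape no longer raises when
ImageNet weights are requested; it warns and falls back to the (224, 224) weights.
A usage sketch, assuming the public tf.keras.applications path is available in this
build (downloads weights on first use):

import tensorflow as tf

# Previously a ValueError; now warns and loads the (224, 224) ImageNet weights.
model = tf.keras.applications.MobileNet(input_shape=(None, None, 3),
                                        weights='imagenet', include_top=False)
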
diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet.py b/tensorflow/python/keras/_impl/keras/applications/nasnet.py
index dd33230a7e..f3412d71be 100644
--- a/tensorflow/python/keras/_impl/keras/applications/nasnet.py
+++ b/tensorflow/python/keras/_impl/keras/applications/nasnet.py
@@ -96,10 +96,9 @@ def NASNet(input_shape=None,
at `~/.keras/keras.json`.
Arguments:
- input_shape: Optional shape tuple, only to be specified
- if `include_top` is False (otherwise the input shape
- has to be `(331, 331, 3)` for NASNetLarge or
- `(224, 224, 3)` for NASNetMobile
+ input_shape: Optional shape tuple, the input shape
+ is by default `(331, 331, 3)` for NASNetLarge and
+ `(224, 224, 3)` for NASNetMobile.
      It should have exactly 3 input channels,
and width and height should be no smaller than 32.
E.g. `(224, 224, 3)` would be one valid value.
@@ -169,6 +168,14 @@ def NASNet(input_shape=None,
raise ValueError('If using `weights` as ImageNet with `include_top` '
'as true, `classes` should be 1000')
+ if (isinstance(input_shape, tuple) and None in input_shape and
+ weights == 'imagenet'):
+ raise ValueError('When specifying the input shape of a NASNet'
+ ' and loading `ImageNet` weights, '
+ 'the input_shape argument must be static '
+ '(no None entries). Got: `input_shape=' +
+ str(input_shape) + '`.')
+
if default_size is None:
default_size = 331
@@ -178,7 +185,7 @@ def NASNet(input_shape=None,
default_size=default_size,
min_size=32,
data_format=K.image_data_format(),
- require_flatten=include_top or weights,
+ require_flatten=False,
weights=weights)
if K.image_data_format() != 'channels_last':
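
The added check turns a late, confusing failure into an explicit one. A sketch of
what it now rejects, using the in-tree module path from this file (illustrative only):

from tensorflow.python.keras._impl.keras.applications import nasnet

try:
    nasnet.NASNetMobile(input_shape=(None, None, 3), weights='imagenet',
                        include_top=False)
except ValueError as err:
    print(err)  # input_shape must be static (no None entries) with ImageNet weights
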
diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg16.py b/tensorflow/python/keras/_impl/keras/applications/vgg16.py
index cefb25063e..25a15475ea 100644
--- a/tensorflow/python/keras/_impl/keras/applications/vgg16.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg16.py
@@ -223,16 +223,6 @@ def VGG16(include_top=True,
cache_subdir='models',
file_hash='6d6bbae143d832006294945121d1f1fc')
model.load_weights(weights_path)
- if K.backend() == 'theano':
- layer_utils.convert_all_kernels_in_model(model)
-
- if K.image_data_format() == 'channels_first':
- if include_top:
- maxpool = model.get_layer(name='block5_pool')
- shape = maxpool.output_shape[1:]
- dense = model.get_layer(name='fc1')
- layer_utils.convert_dense_weights_data_format(dense, shape,
- 'channels_first')
elif weights is not None:
model.load_weights(weights)
diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg19.py b/tensorflow/python/keras/_impl/keras/applications/vgg19.py
index dadaf4fdf0..b09d0068b7 100644
--- a/tensorflow/python/keras/_impl/keras/applications/vgg19.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg19.py
@@ -232,16 +232,6 @@ def VGG19(include_top=True,
cache_subdir='models',
file_hash='253f8cb515780f3b799900260a226db6')
model.load_weights(weights_path)
- if K.backend() == 'theano':
- layer_utils.convert_all_kernels_in_model(model)
-
- if K.image_data_format() == 'channels_first':
- if include_top:
- maxpool = model.get_layer(name='block5_pool')
- shape = maxpool.output_shape[1:]
- dense = model.get_layer(name='fc1')
- layer_utils.convert_dense_weights_data_format(dense, shape,
- 'channels_first')
elif weights is not None:
model.load_weights(weights)
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index b1f1270623..af3d1fa33d 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -74,12 +74,6 @@ _SESSION = None
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
_GRAPH_LEARNING_PHASES = {}
-# This dictionary holds a mapping {graph: UID_DICT}.
-# each UID_DICT is a dictionary mapping name prefixes to a current index,
-# used for generating graph-specific string UIDs
-# for various names (e.g. layer names).
-_GRAPH_UID_DICTS = {}
-
# This boolean flag can be set to True to leave variable initialization
# up to the user.
# Change its value via `manual_variable_initialization(value)`.
@@ -298,6 +292,8 @@ def get_uid(prefix=''):
@tf_export('keras.backend.reset_uids')
def reset_uids():
+ """Resets graph identifiers.
+ """
per_graph_layer_name_uids = PER_GRAPH_LAYER_NAME_UIDS
keys = list(per_graph_layer_name_uids.keys())
for key in keys:
@@ -1421,6 +1417,9 @@ def batch_dot(x, y, axes=None):
axes = (axes, axes)
x_ndim = ndim(x)
y_ndim = ndim(y)
+ if axes is None:
+ # behaves like tf.batch_matmul as default
+ axes = [x_ndim - 1, y_ndim - 2]
if x_ndim > y_ndim:
diff = x_ndim - y_ndim
y = array_ops.reshape(y,
@@ -2927,7 +2926,7 @@ def function(inputs, outputs, updates=None, **kwargs):
@tf_export('keras.backend.gradients')
def gradients(loss, variables):
- """Returns the gradients of `variables` w.r.t. `loss`.
+ """Returns the gradients of `loss` w.r.t. `variables`.
Arguments:
loss: Scalar tensor to minimize.
@@ -3395,16 +3394,18 @@ def elu(x, alpha=1.):
@tf_export('keras.backend.softmax')
-def softmax(x):
+def softmax(x, axis=-1):
"""Softmax of a tensor.
Arguments:
x: A tensor or variable.
+ axis: The dimension softmax would be performed on.
+ The default is -1 which indicates the last dimension.
Returns:
A tensor.
"""
- return nn.softmax(x)
+ return nn.softmax(x, axis=axis)
@tf_export('keras.backend.softplus')
@@ -4588,8 +4589,8 @@ def ctc_batch_cost(y_true, y_pred, input_length, label_length):
Tensor with shape (samples,1) containing the
CTC loss of each element.
"""
- label_length = math_ops.to_int32(array_ops.squeeze(label_length))
- input_length = math_ops.to_int32(array_ops.squeeze(input_length))
+ label_length = math_ops.to_int32(array_ops.squeeze(label_length, axis=-1))
+ input_length = math_ops.to_int32(array_ops.squeeze(input_length, axis=-1))
sparse_labels = math_ops.to_int32(
ctc_label_dense_to_sparse(y_true, label_length))
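
A small usage sketch of the new axis argument to softmax (module path as imported
elsewhere in this change; a minimal sketch, not a full example):

import numpy as np
from tensorflow.python.keras._impl.keras import backend as K

x = K.variable(np.random.rand(2, 3, 4))
default = K.softmax(x)           # unchanged behaviour: softmax over the last axis
per_axis = K.softmax(x, axis=1)  # new: softmax over an axis of your choosing
print(K.eval(per_axis).shape)    # (2, 3, 4)
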
diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py
index de1ed467a2..b2243473aa 100644
--- a/tensorflow/python/keras/_impl/keras/backend_test.py
+++ b/tensorflow/python/keras/_impl/keras/backend_test.py
@@ -1122,6 +1122,30 @@ class TestCTC(test.TestCase):
keras.backend.ctc_batch_cost(labels, inputs, input_lens, label_lens))
self.assertAllClose(res[:, 0], loss_log_probs, atol=1e-05)
+ # test when batch_size = 1, that is, one sample only
+ ref = [3.34211]
+ input_lens = np.expand_dims(np.asarray([5]), 1)
+ label_lens = np.expand_dims(np.asarray([5]), 1)
+
+ labels = np.asarray([[0, 1, 2, 1, 0]])
+ inputs = np.asarray(
+ [[[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553], [
+ 0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436
+ ], [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
+ [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
+ [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]]
+ ],
+ dtype=np.float32)
+
+ k_labels = keras.backend.variable(labels, dtype='int32')
+ k_inputs = keras.backend.variable(inputs, dtype='float32')
+ k_input_lens = keras.backend.variable(input_lens, dtype='int32')
+ k_label_lens = keras.backend.variable(label_lens, dtype='int32')
+ res = keras.backend.eval(
+ keras.backend.ctc_batch_cost(k_labels, k_inputs, k_input_lens,
+ k_label_lens))
+ self.assertAllClose(res[:, 0], ref, atol=1e-05)
+
class TestRandomOps(test.TestCase):
diff --git a/tensorflow/python/keras/_impl/keras/callbacks.py b/tensorflow/python/keras/_impl/keras/callbacks.py
index deb1e8867d..7eb8c12af6 100644
--- a/tensorflow/python/keras/_impl/keras/callbacks.py
+++ b/tensorflow/python/keras/_impl/keras/callbacks.py
@@ -268,9 +268,6 @@ class TerminateOnNaN(Callback):
"""Callback that terminates training when a NaN loss is encountered.
"""
- def __init__(self):
- super(TerminateOnNaN, self).__init__()
-
def on_batch_end(self, batch, logs=None):
logs = logs or {}
loss = logs.get('loss')
@@ -468,8 +465,8 @@ class ModelCheckpoint(Callback):
self.model.save(filepath, overwrite=True)
else:
if self.verbose > 0:
- print('\nEpoch %05d: %s did not improve' % (epoch + 1,
- self.monitor))
+ print('\nEpoch %05d: %s did not improve from %0.5f' %
+ (epoch + 1, self.monitor, self.best))
else:
if self.verbose > 0:
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
@@ -571,25 +568,33 @@ class RemoteMonitor(Callback):
Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
HTTP POST, with a `data` argument which is a
JSON-encoded dictionary of event data.
+ If send_as_json is set to True, the content type of the request will be
+ application/json. Otherwise the serialized JSON will be sent within a form.
Arguments:
root: String; root url of the target server.
path: String; path relative to `root` to which the events will be sent.
field: String; JSON field under which the data will be stored.
+ The field is used only if the payload is sent within a form
+ (i.e. send_as_json is set to False).
headers: Dictionary; optional custom HTTP headers.
+ send_as_json: Boolean; whether the request should be
+ sent as application/json.
"""
def __init__(self,
root='http://localhost:9000',
path='/publish/epoch/end/',
field='data',
- headers=None):
+ headers=None,
+ send_as_json=False):
super(RemoteMonitor, self).__init__()
self.root = root
self.path = path
self.field = field
self.headers = headers
+ self.send_as_json = send_as_json
def on_epoch_end(self, epoch, logs=None):
if requests is None:
@@ -600,9 +605,12 @@ class RemoteMonitor(Callback):
for k, v in logs.items():
send[k] = v
try:
- requests.post(
- self.root + self.path, {self.field: json.dumps(send)},
- headers=self.headers)
+ if self.send_as_json:
+ requests.post(self.root + self.path, json=send, headers=self.headers)
+ else:
+ requests.post(
+ self.root + self.path, {self.field: json.dumps(send)},
+ headers=self.headers)
except requests.exceptions.RequestException:
logging.warning('Warning: could not reach RemoteMonitor '
'root server at ' + str(self.root))
@@ -846,7 +854,7 @@ class ReduceLROnPlateau(Callback):
monitored has stopped increasing; in `auto`
mode, the direction is automatically inferred
from the name of the monitored quantity.
- epsilon: threshold for measuring the new optimum,
+ min_delta: threshold for measuring the new optimum,
to only focus on significant changes.
cooldown: number of epochs to wait before resuming
normal operation after lr has been reduced.
@@ -859,17 +867,22 @@ class ReduceLROnPlateau(Callback):
patience=10,
verbose=0,
mode='auto',
- epsilon=1e-4,
+ min_delta=1e-4,
cooldown=0,
- min_lr=0):
+ min_lr=0,
+ **kwargs):
super(ReduceLROnPlateau, self).__init__()
self.monitor = monitor
if factor >= 1.0:
raise ValueError('ReduceLROnPlateau ' 'does not support a factor >= 1.0.')
+ if 'epsilon' in kwargs:
+ min_delta = kwargs.pop('epsilon')
+ logging.warning('`epsilon` argument is deprecated and '
+ 'will be removed, use `min_delta` instead.')
self.factor = factor
self.min_lr = min_lr
- self.epsilon = epsilon
+ self.min_delta = min_delta
self.patience = patience
self.verbose = verbose
self.cooldown = cooldown
@@ -889,10 +902,10 @@ class ReduceLROnPlateau(Callback):
self.mode = 'auto'
if (self.mode == 'min' or
(self.mode == 'auto' and 'acc' not in self.monitor)):
- self.monitor_op = lambda a, b: np.less(a, b - self.epsilon)
+ self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
self.best = np.Inf
else:
- self.monitor_op = lambda a, b: np.greater(a, b + self.epsilon)
+ self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
self.best = -np.Inf
self.cooldown_counter = 0
self.wait = 0
@@ -918,6 +931,7 @@ class ReduceLROnPlateau(Callback):
self.best = current
self.wait = 0
elif not self.in_cooldown():
+ self.wait += 1
if self.wait >= self.patience:
old_lr = float(K.get_value(self.model.optimizer.lr))
if old_lr > self.min_lr:
@@ -929,7 +943,6 @@ class ReduceLROnPlateau(Callback):
'rate to %s.' % (epoch + 1, new_lr))
self.cooldown_counter = self.cooldown
self.wait = 0
- self.wait += 1
def in_cooldown(self):
return self.cooldown_counter > 0
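
Usage sketch for the two user-visible callback changes above -- RemoteMonitor's
send_as_json flag and ReduceLROnPlateau's epsilon-to-min_delta rename -- using the
module path the tests below use:

from tensorflow.python.keras._impl import keras

callbacks = [
    # Posts the epoch logs as an application/json body instead of a form field.
    keras.callbacks.RemoteMonitor(root='http://localhost:9000', send_as_json=True),
    # `epsilon` still works but warns; `min_delta` is the new spelling.
    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                                      patience=2, min_delta=1e-4),
]
# model.fit(x, y, validation_data=(x_val, y_val), callbacks=callbacks)
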
diff --git a/tensorflow/python/keras/_impl/keras/callbacks_test.py b/tensorflow/python/keras/_impl/keras/callbacks_test.py
index 79dfcd1bb6..468e5dddf8 100644
--- a/tensorflow/python/keras/_impl/keras/callbacks_test.py
+++ b/tensorflow/python/keras/_impl/keras/callbacks_test.py
@@ -30,6 +30,7 @@ import numpy as np
from tensorflow.python.keras._impl import keras
from tensorflow.python.keras._impl.keras import testing_utils
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary.writer import writer_cache
try:
@@ -354,7 +355,7 @@ class KerasCallbacksTest(test.TestCase):
keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=0.1,
- epsilon=10,
+ min_delta=10,
patience=1,
cooldown=5)
]
@@ -371,6 +372,63 @@ class KerasCallbacksTest(test.TestCase):
0.01,
atol=1e-4)
+ model = make_model()
+ cbks = [
+ keras.callbacks.ReduceLROnPlateau(
+ monitor='val_loss',
+ factor=0.1,
+ min_delta=0,
+ patience=1,
+ cooldown=5)
+ ]
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=5,
+ verbose=2)
+ self.assertAllClose(
+ float(keras.backend.get_value(model.optimizer.lr)), 0.1, atol=1e-4)
+
+ def test_ReduceLROnPlateau_patience(self):
+
+ class DummyOptimizer(object):
+
+ def __init__(self):
+ self.lr = keras.backend.variable(1.0)
+
+ class DummyModel(object):
+
+ def __init__(self):
+ self.optimizer = DummyOptimizer()
+
+ reduce_on_plateau = keras.callbacks.ReduceLROnPlateau(
+ monitor='val_loss', patience=2)
+ reduce_on_plateau.model = DummyModel()
+
+ losses = [0.0860, 0.1096, 0.1040]
+ lrs = []
+
+ for epoch in range(len(losses)):
+ reduce_on_plateau.on_epoch_end(epoch, logs={'val_loss': losses[epoch]})
+ lrs.append(keras.backend.get_value(reduce_on_plateau.model.optimizer.lr))
+
+ # The learning rates should be 1.0 except the last one
+ for lr in lrs[:-1]:
+ self.assertEqual(lr, 1.0)
+ self.assertLess(lrs[-1], 1.0)
+
+ def test_ReduceLROnPlateau_backwards_compatibility(self):
+ with test.mock.patch.object(logging, 'warning') as mock_log:
+ reduce_on_plateau = keras.callbacks.ReduceLROnPlateau(epsilon=1e-13)
+ self.assertRegexpMatches(
+ str(mock_log.call_args), '`epsilon` argument is deprecated')
+ self.assertFalse(hasattr(reduce_on_plateau, 'epsilon'))
+ self.assertTrue(hasattr(reduce_on_plateau, 'min_delta'))
+ self.assertEqual(reduce_on_plateau.min_delta, 1e-13)
+
def test_CSVLogger(self):
with self.test_session():
np.random.seed(1337)
@@ -507,33 +565,39 @@ class KerasCallbacksTest(test.TestCase):
assert 'nan' in values[-1], 'The last epoch was not logged.'
def test_TerminateOnNaN(self):
- np.random.seed(1337)
- (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
- train_samples=TRAIN_SAMPLES,
- test_samples=TEST_SAMPLES,
- input_shape=(INPUT_DIM,),
- num_classes=NUM_CLASSES)
+ with self.test_session():
+ np.random.seed(1337)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=TRAIN_SAMPLES,
+ test_samples=TEST_SAMPLES,
+ input_shape=(INPUT_DIM,),
+ num_classes=NUM_CLASSES)
- y_test = keras.utils.to_categorical(y_test)
- y_train = keras.utils.to_categorical(y_train)
- cbks = [keras.callbacks.TerminateOnNaN()]
- model = keras.models.Sequential()
- initializer = keras.initializers.Constant(value=1e5)
- for _ in range(5):
- model.add(keras.layers.Dense(2,
- input_dim=INPUT_DIM,
- activation='relu',
- kernel_initializer=initializer))
- model.add(keras.layers.Dense(NUM_CLASSES))
- model.compile(loss='mean_squared_error',
- optimizer='rmsprop')
-
- history = model.fit(x_train, y_train, batch_size=BATCH_SIZE,
- validation_data=(x_test, y_test),
- callbacks=cbks, epochs=20)
- loss = history.history['loss']
- assert len(loss) == 1
- assert loss[0] == np.inf
+ y_test = keras.utils.to_categorical(y_test)
+ y_train = keras.utils.to_categorical(y_train)
+ cbks = [keras.callbacks.TerminateOnNaN()]
+ model = keras.models.Sequential()
+ initializer = keras.initializers.Constant(value=1e5)
+ for _ in range(5):
+ model.add(
+ keras.layers.Dense(
+ 2,
+ input_dim=INPUT_DIM,
+ activation='relu',
+ kernel_initializer=initializer))
+ model.add(keras.layers.Dense(NUM_CLASSES))
+ model.compile(loss='mean_squared_error', optimizer='rmsprop')
+
+ history = model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=20)
+ loss = history.history['loss']
+ assert len(loss) == 1
+ assert loss[0] == np.inf
def test_TensorBoard(self):
np.random.seed(1337)
@@ -875,6 +939,37 @@ class KerasCallbacksTest(test.TestCase):
assert os.path.exists(temp_dir)
+ def test_RemoteMonitorWithJsonPayload(self):
+    if requests is None:
+ self.skipTest('`requests` required to run this test')
+ with self.test_session():
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=TRAIN_SAMPLES,
+ test_samples=TEST_SAMPLES,
+ input_shape=(INPUT_DIM,),
+ num_classes=NUM_CLASSES)
+ y_test = keras.utils.np_utils.to_categorical(y_test)
+ y_train = keras.utils.np_utils.to_categorical(y_train)
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Dense(
+ NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
+ model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='rmsprop',
+ metrics=['accuracy'])
+ cbks = [keras.callbacks.RemoteMonitor(send_as_json=True)]
+
+ with test.mock.patch.object(requests, 'post'):
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=1)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
index 16ee2952b2..5dc93806f4 100644
--- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
@@ -25,7 +25,7 @@ import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
+from tensorflow.python.estimator import util as function_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -44,6 +44,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import checkpointable
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -146,7 +147,7 @@ class Layer(checkpointable.CheckpointableBase):
# return tensors. When using graph execution, _losses is a list of ops.
self._losses = []
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
- self._call_fn_args = estimator_util.fn_args(self.call)
+ self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
hasattr(self, 'compute_mask'))
self._uses_inputs_arg = True
@@ -484,8 +485,7 @@ class Layer(checkpointable.CheckpointableBase):
"""
if dtype is None:
dtype = self.dtype or backend.floatx()
- else:
- dtype = dtypes.as_dtype(dtype)
+ dtype = dtypes.as_dtype(dtype)
initializer = initializers.get(initializer)
regularizer = regularizers.get(regularizer)
constraint = constraints.get(constraint)
@@ -513,7 +513,7 @@ class Layer(checkpointable.CheckpointableBase):
# Manage errors in Layer rather than Checkpointable.
overwrite=True,
initializer=initializer,
- dtype=dtypes.as_dtype(dtype),
+ dtype=dtype,
constraint=constraint,
trainable=trainable and self.trainable,
partitioner=partitioner,
@@ -644,7 +644,7 @@ class Layer(checkpointable.CheckpointableBase):
self._compute_previous_mask):
previous_mask = collect_previous_mask(inputs)
if not hasattr(self, '_call_fn_args'):
- self._call_fn_args = estimator_util.fn_args(self.call)
+ self._call_fn_args = function_utils.fn_args(self.call)
if ('mask' in self._call_fn_args and 'mask' not in kwargs and
not generic_utils.is_all_none(previous_mask)):
# The previous layer generated a mask, and mask was not explicitly pass
diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py
index b7fab6e974..f7cf44af59 100644
--- a/tensorflow/python/keras/_impl/keras/engine/network.py
+++ b/tensorflow/python/keras/_impl/keras/engine/network.py
@@ -839,10 +839,14 @@ class Network(base_layer.Layer):
output_tensors = nest.flatten(
layer.call(computed_tensor, **kwargs))
if hasattr(layer, 'compute_mask'):
- output_masks = nest.flatten(
- layer.compute_mask(computed_tensor, computed_mask))
+ output_masks = layer.compute_mask(computed_tensor,
+ computed_mask)
+ if output_masks is None:
+ output_masks = [None for _ in output_tensors]
+ else:
+ output_masks = nest.flatten(output_masks)
else:
- output_masks = [None for _ in range(len(output_tensors))]
+ output_masks = [None for _ in output_tensors]
computed_tensors = [computed_tensor]
computed_masks = [computed_mask]
else:
@@ -855,11 +859,16 @@ class Network(base_layer.Layer):
output_tensors = nest.flatten(
layer.call(computed_tensors, **kwargs))
+
if hasattr(layer, 'compute_mask'):
- output_masks = nest.flatten(
- layer.compute_mask(computed_tensors, computed_masks))
+ output_masks = layer.compute_mask(computed_tensors,
+ computed_masks)
+ if output_masks is None:
+ output_masks = [None for _ in output_tensors]
+ else:
+ output_masks = nest.flatten(output_masks)
else:
- output_masks = [None for _ in range(len(output_tensors))]
+ output_masks = [None for _ in output_tensors]
if not context.executing_eagerly():
if layer.activity_regularizer is not None:
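
The mask handling above matters for layers whose compute_mask returns None for a
multi-output call (e.g. a multi-output Lambda): previously nest.flatten(None)
yielded a one-element mask list that did not line up with the outputs. A minimal
sketch of the case it fixes, mirroring the new topology test further down:

import numpy as np
from tensorflow.python.keras._impl import keras

i = keras.layers.Input(shape=(3, 2, 1))
two = keras.layers.Lambda(lambda x: [x * 0.2, x * 0.3],
                          output_shape=lambda s: [s, s])(i)  # two outputs, mask is None
model = keras.Model(i, keras.layers.add(two))

i2 = keras.layers.Input(shape=(3, 2, 1))
wrapped = keras.Model(i2, model(i2))  # re-running the graph exercises the mask path above
print(wrapped.predict(np.random.random((4, 3, 2, 1))).shape)  # (4, 3, 2, 1)
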
diff --git a/tensorflow/python/keras/_impl/keras/engine/saving.py b/tensorflow/python/keras/_impl/keras/engine/saving.py
index a0b709a1a5..6a3ae3b20c 100644
--- a/tensorflow/python/keras/_impl/keras/engine/saving.py
+++ b/tensorflow/python/keras/_impl/keras/engine/saving.py
@@ -30,6 +30,7 @@ from tensorflow.python.keras._impl.keras import optimizers
from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import serialization
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=g-import-not-at-top
@@ -61,7 +62,9 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
Arguments:
model: Keras model instance to be saved.
- filepath: String, path where to save the model.
+ filepath: One of the following:
+ - String, path where to save the model
+ - `h5py.File` object where to save the model
overwrite: Whether we should overwrite any existing
model at the target location, or instead
ask the user with a manual prompt.
@@ -74,49 +77,22 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
if h5py is None:
raise ImportError('`save_model` requires h5py.')
- def get_json_type(obj):
- """Serializes any object to a JSON-serializable structure.
-
- Arguments:
- obj: the object to serialize
-
- Returns:
- JSON-serializable structure representing `obj`.
-
- Raises:
- TypeError: if `obj` cannot be serialized.
- """
- # if obj is a serializable Keras class instance
- # e.g. optimizer, layer
- if hasattr(obj, 'get_config'):
- return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
-
- # if obj is any numpy type
- if type(obj).__module__ == np.__name__:
- if isinstance(obj, np.ndarray):
- return {'type': type(obj), 'value': obj.tolist()}
- else:
- return obj.item()
-
- # misc functions (e.g. loss function)
- if callable(obj):
- return obj.__name__
-
- # if obj is a python 'type'
- if type(obj).__name__ == type.__name__:
- return obj.__name__
-
- raise TypeError('Not JSON Serializable:', obj)
-
from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
- # If file exists and should not be overwritten.
- if not overwrite and os.path.isfile(filepath):
- proceed = ask_to_proceed_with_overwrite(filepath)
- if not proceed:
- return
+ if not isinstance(filepath, h5py.File):
+ # If file exists and should not be overwritten.
+ if not overwrite and os.path.isfile(filepath):
+ proceed = ask_to_proceed_with_overwrite(filepath)
+ if not proceed:
+ return
- with h5py.File(filepath, mode='w') as f:
+ f = h5py.File(filepath, mode='w')
+ opened_new_file = True
+ else:
+ f = filepath
+ opened_new_file = False
+
+ try:
f.attrs['keras_version'] = str(keras_version).encode('utf8')
f.attrs['backend'] = K.backend().encode('utf8')
f.attrs['model_config'] = json.dumps(
@@ -124,7 +100,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
'class_name': model.__class__.__name__,
'config': model.get_config()
},
- default=get_json_type).encode('utf8')
+ default=serialization.get_json_type).encode('utf8')
model_weights_group = f.create_group('model_weights')
model_layers = model.layers
@@ -154,7 +130,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
'sample_weight_mode': model.sample_weight_mode,
'loss_weights': model.loss_weights,
},
- default=get_json_type).encode('utf8')
+ default=serialization.get_json_type).encode('utf8')
# Save optimizer weights.
symbolic_weights = getattr(model.optimizer, 'weights')
@@ -175,6 +151,9 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
else:
param_dset[:] = val
f.flush()
+ finally:
+ if opened_new_file:
+ f.close()
@tf_export('keras.models.load_model')
@@ -182,7 +161,9 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
"""Loads a model saved via `save_model`.
Arguments:
- filepath: String, path to the saved model.
+ filepath: One of the following:
+ - String, path to the saved model
+ - `h5py.File` object from which to load the model
custom_objects: Optional dictionary mapping names
(strings) to custom classes or functions to be
considered during deserialization.
@@ -232,7 +213,14 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
return custom_objects[obj]
return obj
- with h5py.File(filepath, mode='r') as f:
+ opened_new_file = not isinstance(filepath, h5py.File)
+ if opened_new_file:
+ f = h5py.File(filepath, mode='r')
+ else:
+ f = filepath
+
+ model = None
+ try:
# instantiate model
model_config = f.attrs.get('model_config')
if model_config is None:
@@ -243,54 +231,54 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
# set weights
load_weights_from_hdf5_group(f['model_weights'], model.layers)
- # Early return if compilation is not required.
- if not compile:
- return model
-
- # instantiate optimizer
- training_config = f.attrs.get('training_config')
- if training_config is None:
- logging.warning('No training configuration found in save file: '
- 'the model was *not* compiled. Compile it manually.')
- return model
- training_config = json.loads(training_config.decode('utf-8'))
- optimizer_config = training_config['optimizer_config']
- optimizer = optimizers.deserialize(
- optimizer_config, custom_objects=custom_objects)
-
- # Recover loss functions and metrics.
- loss = convert_custom_objects(training_config['loss'])
- metrics = convert_custom_objects(training_config['metrics'])
- sample_weight_mode = training_config['sample_weight_mode']
- loss_weights = training_config['loss_weights']
-
- # Compile model.
- model.compile(
- optimizer=optimizer,
- loss=loss,
- metrics=metrics,
- loss_weights=loss_weights,
- sample_weight_mode=sample_weight_mode)
-
- # Set optimizer weights.
- if 'optimizer_weights' in f:
- # Build train function (to get weight updates).
- model._make_train_function()
- optimizer_weights_group = f['optimizer_weights']
- optimizer_weight_names = [
- n.decode('utf8')
- for n in optimizer_weights_group.attrs['weight_names']
- ]
- optimizer_weight_values = [
- optimizer_weights_group[n] for n in optimizer_weight_names
- ]
- try:
- model.optimizer.set_weights(optimizer_weight_values)
- except ValueError:
- logging.warning('Error in loading the saved optimizer '
- 'state. As a result, your model is '
- 'starting with a freshly initialized '
- 'optimizer.')
+ if compile:
+ # instantiate optimizer
+ training_config = f.attrs.get('training_config')
+ if training_config is None:
+ logging.warning('No training configuration found in save file: '
+ 'the model was *not* compiled. Compile it manually.')
+ return model
+ training_config = json.loads(training_config.decode('utf-8'))
+ optimizer_config = training_config['optimizer_config']
+ optimizer = optimizers.deserialize(
+ optimizer_config, custom_objects=custom_objects)
+
+ # Recover loss functions and metrics.
+ loss = convert_custom_objects(training_config['loss'])
+ metrics = convert_custom_objects(training_config['metrics'])
+ sample_weight_mode = training_config['sample_weight_mode']
+ loss_weights = training_config['loss_weights']
+
+ # Compile model.
+ model.compile(
+ optimizer=optimizer,
+ loss=loss,
+ metrics=metrics,
+ loss_weights=loss_weights,
+ sample_weight_mode=sample_weight_mode)
+
+ # Set optimizer weights.
+ if 'optimizer_weights' in f:
+ # Build train function (to get weight updates).
+ model._make_train_function()
+ optimizer_weights_group = f['optimizer_weights']
+ optimizer_weight_names = [
+ n.decode('utf8')
+ for n in optimizer_weights_group.attrs['weight_names']
+ ]
+ optimizer_weight_values = [
+ optimizer_weights_group[n] for n in optimizer_weight_names
+ ]
+ try:
+ model.optimizer.set_weights(optimizer_weight_values)
+ except ValueError:
+ logging.warning('Error in loading the saved optimizer '
+ 'state. As a result, your model is '
+ 'starting with a freshly initialized '
+ 'optimizer.')
+ finally:
+ if opened_new_file:
+ f.close()
return model
@@ -669,6 +657,12 @@ def _convert_rnn_weights(layer, weights):
def save_weights_to_hdf5_group(f, layers):
+ """Saves the weights of a list of layers to a HDF5 group.
+
+ Arguments:
+ f: HDF5 group.
+ layers: List of layer instances.
+ """
from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
save_attributes_to_hdf5_group(
@@ -743,7 +737,7 @@ def load_weights_from_hdf5_group(f, layers):
for k, name in enumerate(layer_names):
g = f[name]
weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
- weight_values = [g[weight_name] for weight_name in weight_names]
+ weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
layer = filtered_layers[k]
symbolic_weights = layer.weights
weight_values = preprocess_weights_for_loading(
@@ -799,7 +793,7 @@ def load_weights_from_hdf5_group_by_name(f, layers):
for k, name in enumerate(layer_names):
g = f[name]
weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
- weight_values = [g[weight_name] for weight_name in weight_names]
+ weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
for layer in index.get(name, []):
symbolic_weights = layer.weights
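
save_model and load_model now accept an already-open h5py.File in addition to a
path string; a minimal sketch, mirroring the new test below:

import h5py
import numpy as np
from tensorflow.python.keras._impl import keras

model = keras.models.Sequential([keras.layers.Dense(2, input_shape=(3,))])
model.compile(optimizer='sgd', loss='mse')

with h5py.File('model.h5', mode='w') as f:       # the caller owns the file handle
    keras.models.save_model(model, f)

with h5py.File('model.h5', mode='r') as f:
    restored = keras.models.load_model(f)
print(restored.predict(np.zeros((1, 3))).shape)  # (1, 2)
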
diff --git a/tensorflow/python/keras/_impl/keras/engine/saving_test.py b/tensorflow/python/keras/_impl/keras/engine/saving_test.py
index c0b16b6bf5..acd104b4fb 100644
--- a/tensorflow/python/keras/_impl/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/saving_test.py
@@ -253,7 +253,7 @@ class TestWholeModelSaving(test.TestCase):
def test_sequential_model_saving(self):
if h5py is None:
- return # Skip test if models cannot be saved.
+ self.skipTest('h5py required to run this test')
with self.test_session():
model = keras.models.Sequential()
@@ -290,7 +290,7 @@ class TestWholeModelSaving(test.TestCase):
def test_sequential_model_saving_2(self):
if h5py is None:
- return # Skip test if models cannot be saved.
+ self.skipTest('h5py required to run this test')
with self.test_session():
# test with custom optimizer, loss
@@ -326,7 +326,7 @@ class TestWholeModelSaving(test.TestCase):
def test_functional_model_saving(self):
if h5py is None:
- return # Skip test if models cannot be saved.
+ self.skipTest('h5py required to run this test')
with self.test_session():
inputs = keras.layers.Input(shape=(3,))
@@ -354,7 +354,7 @@ class TestWholeModelSaving(test.TestCase):
def test_saving_without_compilation(self):
if h5py is None:
- return # Skip test if models cannot be saved.
+ self.skipTest('h5py required to run this test')
with self.test_session():
model = keras.models.Sequential()
@@ -370,7 +370,7 @@ class TestWholeModelSaving(test.TestCase):
def test_saving_with_tf_optimizer(self):
if h5py is None:
- return # Skip test if models cannot be saved.
+ self.skipTest('h5py required to run this test')
with self.test_session():
model = keras.models.Sequential()
@@ -388,7 +388,7 @@ class TestWholeModelSaving(test.TestCase):
def test_saving_right_after_compilation(self):
if h5py is None:
- return # Skip test if models cannot be saved.
+ self.skipTest('h5py required to run this test')
with self.test_session():
model = keras.models.Sequential()
@@ -405,7 +405,7 @@ class TestWholeModelSaving(test.TestCase):
def test_saving_lambda_numpy_array_arguments(self):
if h5py is None:
- return # Skip test if models cannot be saved.
+ self.skipTest('h5py required to run this test')
mean = np.random.random((4, 2, 3))
std = np.abs(np.random.random((4, 2, 3))) + 1e-5
@@ -427,7 +427,7 @@ class TestWholeModelSaving(test.TestCase):
def test_saving_model_with_long_layer_names(self):
if h5py is None:
- return # Skip test if models cannot be saved.
+ self.skipTest('h5py required to run this test')
with self.test_session():
# This layer name will make the `layers_name` HDF5 attribute blow
@@ -468,7 +468,7 @@ class TestWholeModelSaving(test.TestCase):
def test_saving_model_with_long_weights_names(self):
if h5py is None:
- return # Skip test if models cannot be saved.
+ self.skipTest('h5py required to run this test')
with self.test_session():
x = keras.Input(shape=(2,), name='nested_model_input')
@@ -511,6 +511,43 @@ class TestWholeModelSaving(test.TestCase):
os.close(fd)
os.remove(fname)
+ def test_model_saving_to_pre_created_h5py_file(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ with self.test_session():
+ inputs = keras.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ outputs = keras.layers.Dense(3)(x)
+
+ model = keras.Model(inputs, outputs)
+ model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.Adam(),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ out = model.predict(x)
+ fd, fname = tempfile.mkstemp('.h5')
+ with h5py.File(fname, mode='r+') as h5file:
+ keras.models.save_model(model, h5file)
+ loaded_model = keras.models.load_model(h5file)
+ out2 = loaded_model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ # Test non-default options in h5
+ with h5py.File('_', driver='core',
+ backing_store=False) as h5file:
+ keras.models.save_model(model, h5file)
+ loaded_model = keras.models.load_model(h5file)
+ out2 = loaded_model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ # Cleanup
+ os.close(fd)
+ os.remove(fname)
+
class SubclassedModel(training.Model):
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology_test.py b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
index 6993a04289..635c446879 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
@@ -883,6 +883,33 @@ class TopologyConstructionTest(test.TestCase):
preds = model.predict(x)
self.assertEqual(np.min(preds), 0.) # At least one unit was dropped.
+ def test_multi_output_model_with_none_masking(self):
+
+ with self.test_session():
+ def func(x):
+ return [x * 0.2, x * 0.3]
+
+ def output_shape(input_shape):
+ return [input_shape, input_shape]
+
+ i = keras.layers.Input(shape=(3, 2, 1))
+ o = keras.layers.Lambda(function=func, output_shape=output_shape)(i)
+
+ self.assertEqual(keras.backend.int_shape(o[0]), (None, 3, 2, 1))
+ self.assertEqual(keras.backend.int_shape(o[1]), (None, 3, 2, 1))
+
+ o = keras.layers.add(o)
+ model = keras.Model(i, o)
+
+ i2 = keras.layers.Input(shape=(3, 2, 1))
+ o2 = model(i2)
+ model2 = keras.Model(i2, o2)
+
+ x = np.random.random((4, 3, 2, 1))
+ out = model2.predict(x)
+ assert out.shape == (4, 3, 2, 1)
+ self.assertAllClose(out, x * 0.2 + x * 0.3, atol=1e-4)
+
class DeferredModeTest(test.TestCase):
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index c7623d2b52..16d1b160e4 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -285,6 +285,10 @@ class Model(Network):
self.metrics_names.append(self.output_names[i] + '_loss')
self.nested_metrics = training_utils.collect_metrics(metrics,
self.output_names)
+ # TODO(fchollet): support stateful metrics in eager execution.
+ self.stateful_metric_functions = []
+ self.stateful_metric_names = []
+
with K.name_scope('metrics'):
training_utils.populate_metric_names(self)
self._feed_sample_weight_modes = []
@@ -461,6 +465,7 @@ class Model(Network):
self.output_names)
self.metrics_updates = []
self.stateful_metric_names = []
+ self.stateful_metric_functions = []
with K.name_scope('metrics'):
for i in range(len(self.outputs)):
if i in skip_target_indices:
@@ -516,8 +521,9 @@ class Model(Network):
# Keep track of state updates created by
# stateful metrics (i.e. metrics layers).
- if isinstance(metric_fn, Layer):
+ if isinstance(metric_fn, Layer) and metric_fn.stateful:
self.stateful_metric_names.append(metric_name)
+ self.stateful_metric_functions.append(metric_fn)
self.metrics_updates += metric_fn.updates
handle_metrics(output_metrics)
@@ -1745,7 +1751,8 @@ class Model(Network):
steps=None,
max_queue_size=10,
workers=1,
- use_multiprocessing=False):
+ use_multiprocessing=False,
+ verbose=0):
"""Evaluates the model on a data generator.
The generator should return the same kind of data
@@ -1772,6 +1779,7 @@ class Model(Network):
Note that because this implementation relies on multiprocessing,
you should not pass non-picklable arguments to the generator
as they can't be passed easily to children processes.
+ verbose: Verbosity mode, 0 or 1.
Returns:
Scalar test loss (if the model has a single output and no metrics)
@@ -1796,7 +1804,8 @@ class Model(Network):
steps=steps,
max_queue_size=max_queue_size,
workers=workers,
- use_multiprocessing=use_multiprocessing)
+ use_multiprocessing=use_multiprocessing,
+ verbose=verbose)
def predict_generator(self,
generator,
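`evaluate_generator` now forwards a `verbose` flag so a progress bar can be shown while evaluating. A small usage sketch, assuming `tf.keras` and mirroring the generator test added below:

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
model.compile(optimizer='rmsprop', loss='mse', metrics=['mae'])

x = np.random.random((10, 3))
y = np.random.random((10, 4))

def generator():
  while True:
    yield x, y

# verbose=1 draws a progress bar over the requested number of steps.
results = model.evaluate_generator(generator(), steps=3, verbose=1)
```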
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py
index 12e74ef51d..84f93da898 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py
@@ -27,7 +27,6 @@ from tensorflow.python.framework import errors
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import callbacks as cbks
from tensorflow.python.keras._impl.keras.engine import training_utils
-from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches
from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays
@@ -180,9 +179,8 @@ def fit_loop(model,
for epoch in range(initial_epoch, epochs):
# Reset stateful metrics
- for m in model.metrics:
- if isinstance(m, Layer):
- m.reset_states()
+ for m in model.stateful_metric_functions:
+ m.reset_states()
# Update callbacks
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
@@ -413,9 +411,8 @@ def test_loop(model, inputs, targets,
ins = inputs + targets + sample_weights
if hasattr(model, 'metrics'):
- for m in model.metrics:
- if isinstance(m, Layer):
- m.reset_states()
+ for m in model.stateful_metric_functions:
+ m.reset_states()
stateful_metric_indices = [
i for i, name in enumerate(model.metrics_names)
if str(name) in model.stateful_metric_names
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
index 3617eb281a..0a98fc2452 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
@@ -501,11 +501,11 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
if verbose == 1:
progbar.update(step_index + 1)
- for i in range(len(outs)):
- outs[i] /= num_samples
- if len(outs) == 1:
- return outs[0]
- return outs
+ for i in range(len(outs)):
+ outs[i] /= num_samples
+ if len(outs) == 1:
+ return outs[0]
+ return outs
def batch_test_loop(model,
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
index 5adb3ef940..2031a8a3dc 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras._impl import keras
@@ -94,7 +95,7 @@ class TrainingTest(test.TestCase):
verbose=2)
model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
- # Test with validation split
+ # Test with validation split
model.fit(
[input_a_np, input_b_np], [output_d_np, output_e_np],
epochs=2,
@@ -402,6 +403,24 @@ class TrainingTest(test.TestCase):
model.train_on_batch(inputs, targets)
model.test_on_batch(inputs, targets)
+ def test_generator_methods(self):
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(4, input_shape=(3,)))
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ model.compile(optimizer, 'mse', metrics=['mae'])
+
+ x = np.random.random((10, 3))
+ y = np.random.random((10, 4))
+
+ def iterator():
+ while 1:
+ yield x, y
+
+ model.fit_generator(iterator(), steps_per_epoch=3, epochs=1)
+ model.evaluate_generator(iterator(), steps=3)
+ out = model.predict_generator(iterator(), steps=3)
+ self.assertEqual(out.shape, (30, 4))
+
class LossWeightingTest(test.TestCase):
@@ -670,6 +689,59 @@ class CorrectnessTest(test.TestCase):
outs = model.evaluate(x, y)
self.assertEqual(outs[1], 0.)
+ @tf_test_util.run_in_graph_and_eager_modes()
+ def test_loss_correctness_with_iterator(self):
+ # Test that training loss is the same in eager and graph
+ # (by comparing it to a reference value in a deterministic case)
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(
+ 3, activation='relu', input_dim=4, kernel_initializer='ones'))
+ model.add(
+ keras.layers.Dense(2, activation='softmax', kernel_initializer='ones'))
+ model.compile(
+ loss='sparse_categorical_crossentropy',
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+ x = np.ones((100, 4), dtype=np.float32)
+ np.random.seed(123)
+ y = np.random.randint(0, 1, size=(100, 1))
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ history = model.fit(iterator, epochs=1, steps_per_epoch=10)
+ self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173)
+
+ @tf_test_util.run_in_graph_and_eager_modes()
+ def test_metrics_correctness_with_iterator(self):
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(
+ 8, activation='relu', input_dim=4, kernel_initializer='ones'))
+ model.add(
+ keras.layers.Dense(1, activation='sigmoid', kernel_initializer='ones'))
+ model.compile(
+ loss='binary_crossentropy',
+ metrics=['accuracy'],
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+ np.random.seed(123)
+ x = np.random.randint(10, size=(100, 4)).astype(np.float32)
+ y = np.random.randint(2, size=(100, 1)).astype(np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ outs = model.evaluate(iterator, steps=10)
+ self.assertEqual(np.around(outs[1], decimals=1), 0.5)
+
+ y = np.zeros((100, 1), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ outs = model.evaluate(iterator, steps=10)
+ self.assertEqual(outs[1], 0.)
+
+
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_generator.py b/tensorflow/python/keras/_impl/keras/engine/training_generator.py
index 58b5bc39c1..0de8297795 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_generator.py
@@ -49,9 +49,6 @@ def fit_generator(model,
epoch = initial_epoch
do_validation = bool(validation_data)
- model._make_train_function()
- if do_validation:
- model._make_test_function()
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
@@ -155,6 +152,8 @@ def fit_generator(model,
# Construct epoch logs.
epoch_logs = {}
while epoch < epochs:
+ for m in model.stateful_metric_functions:
+ m.reset_states()
callbacks.on_epoch_begin(epoch)
steps_done = 0
batch_index = 0
@@ -250,9 +249,18 @@ def evaluate_generator(model,
steps=None,
max_queue_size=10,
workers=1,
- use_multiprocessing=False):
+ use_multiprocessing=False,
+ verbose=0):
"""See docstring for `Model.evaluate_generator`."""
- model._make_test_function()
+ stateful_metric_indices = []
+ if hasattr(model, 'metrics'):
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ stateful_metric_indices = [
+ i for i, name in enumerate(model.metrics_names)
+ if str(name) in model.stateful_metric_names]
+ else:
+ stateful_metric_indices = []
steps_done = 0
wait_time = 0.01
@@ -293,6 +301,9 @@ def evaluate_generator(model,
else:
output_generator = generator
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
while steps_done < steps:
generator_output = next(output_generator)
if not hasattr(generator_output, '__len__'):
@@ -323,6 +334,8 @@ def evaluate_generator(model,
steps_done += 1
batch_sizes.append(batch_size)
+ if verbose == 1:
+ progbar.update(steps_done)
finally:
if enqueuer is not None:
@@ -333,8 +346,11 @@ def evaluate_generator(model,
else:
averages = []
for i in range(len(outs)):
- averages.append(
- np.average([out[i] for out in all_outs], weights=batch_sizes))
+ if i not in stateful_metric_indices:
+ averages.append(
+ np.average([out[i] for out in all_outs], weights=batch_sizes))
+ else:
+ averages.append(float(all_outs[-1][i]))
return averages
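The aggregation above treats stateful metrics differently from per-batch outputs: the loss and stateless metrics are averaged with batch-size weights, while a stateful metric already carries its running value, so only the last batch's reading is kept. A small numeric sketch of that rule, with illustrative values:

```python
import numpy as np

# Per-batch outputs of [loss, true_positives]; index 1 is a stateful metric.
all_outs = [[0.9, 3.0], [0.7, 5.0]]
batch_sizes = [32, 16]
stateful_metric_indices = [1]

averages = []
for i in range(len(all_outs[0])):
  if i not in stateful_metric_indices:
    # Weighted average over batches, as for the loss and stateless metrics.
    averages.append(np.average([out[i] for out in all_outs], weights=batch_sizes))
  else:
    # The stateful metric accumulated across batches; keep its final value.
    averages.append(float(all_outs[-1][i]))

print(averages)  # [0.8333..., 5.0]
```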
@@ -346,8 +362,6 @@ def predict_generator(model,
use_multiprocessing=False,
verbose=0):
"""See docstring for `Model.predict_generator`."""
- model._make_predict_function()
-
steps_done = 0
wait_time = 0.01
all_outs = []
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index cc2386a5bd..4b01fbb165 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -947,6 +947,7 @@ class TestGeneratorMethods(test.TestCase):
steps=5,
max_queue_size=10,
workers=2,
+ verbose=1,
use_multiprocessing=True)
model.evaluate_generator(custom_generator(),
steps=5,
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
index 9971f12773..e47aaf9cac 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
@@ -1467,10 +1467,14 @@ class SeparableConv2D(SeparableConv):
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
+ dilation_rate: An integer or tuple/list of 2 integers, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
depth_multiplier: The number of depthwise convolution output channels
for each input channel.
The total number of depthwise convolution output
- channels will be equal to `filterss_in * depth_multiplier`.
+ channels will be equal to `filters_in * depth_multiplier`.
activation: Activation function to use.
If you don't specify anything, no activation is applied
(ie. "linear" activation: `a(x) = x`).
@@ -1511,7 +1515,7 @@ class SeparableConv2D(SeparableConv):
strides=(1, 1),
padding='valid',
data_format=None,
- dilation_rate=1,
+ dilation_rate=(1, 1),
depth_multiplier=1,
activation=None,
use_bias=True,
@@ -2095,14 +2099,14 @@ class ZeroPadding3D(Layer):
"""Zero-padding layer for 3D data (spatial or spatio-temporal).
Arguments:
- padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
+ padding: int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
- If int: the same symmetric padding
is applied to width and height.
- - If tuple of 2 ints:
+ - If tuple of 3 ints:
interpreted as three different
symmetric padding values for depth, height, and width:
`(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)`.
- - If tuple of 2 tuples of 2 ints:
+ - If tuple of 3 tuples of 2 ints:
interpreted as
`((left_dim1_pad, right_dim1_pad), (left_dim2_pad,
right_dim2_pad), (left_dim3_pad, right_dim3_pad))`
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
index be25bbc043..9cad08274e 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
@@ -29,6 +29,7 @@ from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
from tensorflow.python.keras._impl.keras.layers.recurrent import _generate_dropout_mask
+from tensorflow.python.keras._impl.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras._impl.keras.layers.recurrent import RNN
from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.keras._impl.keras.utils import generic_utils
@@ -167,6 +168,7 @@ class ConvRNN2D(RNN):
**kwargs)
self.input_spec = [InputSpec(ndim=5)]
self.states = None
+ self._num_constants = None
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
@@ -214,7 +216,7 @@ class ConvRNN2D(RNN):
# Note input_shape will be list of shapes of initial states and
# constants if these are passed in __call__.
if self._num_constants is not None:
- constants_shape = input_shape[-self._num_constants:]
+ constants_shape = input_shape[-self._num_constants:] # pylint: disable=E1130
else:
constants_shape = None
@@ -279,8 +281,8 @@ class ConvRNN2D(RNN):
return [initial_state]
def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
- inputs, initial_state, constants = self._standardize_args(
- inputs, initial_state, constants)
+ inputs, initial_state, constants = _standardize_args(
+ inputs, initial_state, constants, self._num_constants)
if initial_state is None and constants is None:
return super(ConvRNN2D, self).__call__(inputs, **kwargs)
@@ -609,16 +611,25 @@ class ConvLSTM2DCell(Layer):
name='recurrent_kernel',
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
+
if self.use_bias:
- self.bias = self.add_weight(shape=(self.filters * 4,),
- initializer=self.bias_initializer,
- name='bias',
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint)
if self.unit_forget_bias:
- bias_value = np.zeros((self.filters * 4,))
- bias_value[self.filters: self.filters * 2] = 1.
- K.set_value(self.bias, bias_value)
+
+ def bias_initializer(_, *args, **kwargs):
+ return K.concatenate([
+ self.bias_initializer((self.filters,), *args, **kwargs),
+ initializers.Ones()((self.filters,), *args, **kwargs),
+ self.bias_initializer((self.filters * 2,), *args, **kwargs),
+ ])
+ else:
+ bias_initializer = self.bias_initializer
+ self.bias = self.add_weight(
+ shape=(self.filters * 4,),
+ name='bias',
+ initializer=bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint)
+
else:
self.bias = None
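The rewritten bias construction folds `unit_forget_bias` into the initializer itself: instead of overwriting the bias with `K.set_value` after creation, the forget-gate slice comes from `initializers.Ones()`. A sketch of the resulting layout for the (input, forget, cell, output) gate order, with illustrative values:

```python
import numpy as np

filters = 4  # illustrative

bias = np.concatenate([
    np.zeros(filters),      # input gate           <- self.bias_initializer (zeros by default)
    np.ones(filters),       # forget gate          <- initializers.Ones()
    np.zeros(filters * 2),  # cell + output gates  <- self.bias_initializer
])

# Same effect as the removed post-hoc assignment bias_value[filters:filters * 2] = 1.
assert np.all(bias[filters:filters * 2] == 1.)
print(bias.shape)  # (16,)
```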
@@ -844,10 +855,10 @@ class ConvLSTM2D(ConvRNN2D):
Input shape:
- if data_format='channels_first'
5D tensor with shape:
- `(samples,time, channels, rows, cols)`
+ `(samples, time, channels, rows, cols)`
- if data_format='channels_last'
5D tensor with shape:
- `(samples,time, rows, cols, channels)`
+ `(samples, time, rows, cols, channels)`
Output shape:
- if `return_sequences`
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py
index 9e768b4e95..827a7ffbda 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py
@@ -180,6 +180,23 @@ class ConvLSTMTest(test.TestCase):
'recurrent_dropout': 0.1},
input_shape=(1, 2, 5, 5, 2))
+ def test_conv_lstm_cloning(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.ConvLSTM2D(5, 3, input_shape=(None, 5, 5, 3)))
+
+ test_inputs = np.random.random((2, 4, 5, 5, 3))
+ reference_outputs = model.predict(test_inputs)
+ weights = model.get_weights()
+
+ # Use a new graph to clone the model
+ with self.test_session():
+ clone = keras.models.clone_model(model)
+ clone.set_weights(weights)
+
+ outputs = clone.predict(test_inputs)
+ self.assertAllClose(reference_outputs, outputs, atol=1e-5)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/layers/core.py b/tensorflow/python/keras/_impl/keras/layers/core.py
index 9c4cb0f4fd..30327781df 100644
--- a/tensorflow/python/keras/_impl/keras/layers/core.py
+++ b/tensorflow/python/keras/_impl/keras/layers/core.py
@@ -33,6 +33,7 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.keras._impl.keras.utils import generic_utils
from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
@@ -501,6 +502,17 @@ class Permute(Layer):
class Flatten(Layer):
"""Flattens the input. Does not affect the batch size.
+ Arguments:
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, ..., channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, ...)`.
+ It defaults to the `image_data_format` value found in your
+ Keras config file at `~/.keras/keras.json`.
+ If you never set it, then it will be "channels_last".
+
Example:
```python
@@ -515,11 +527,19 @@ class Flatten(Layer):
```
"""
- def __init__(self, **kwargs):
+ def __init__(self, data_format=None, **kwargs):
super(Flatten, self).__init__(**kwargs)
+ self.data_format = conv_utils.normalize_data_format(data_format)
self.input_spec = InputSpec(min_ndim=2)
def call(self, inputs):
+ if self.data_format == 'channels_first':
+ permutation = [0]
+ permutation.extend([i for i in
+ range(2, K.ndim(inputs))])
+ permutation.append(1)
+ inputs = array_ops.transpose(inputs, perm=permutation)
+
outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
if not context.executing_eagerly():
outputs.set_shape(self.compute_output_shape(inputs.get_shape()))
@@ -534,6 +554,11 @@ class Flatten(Layer):
output_shape += [None]
return tensor_shape.TensorShape(output_shape)
+ def get_config(self):
+ config = {'data_format': self.data_format}
+ base_config = super(Flatten, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
@tf_export('keras.layers.RepeatVector')
class RepeatVector(Layer):
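`Flatten` now accepts a `data_format` argument; with `channels_first` inputs it transposes to channels-last before flattening, so the feature order matches a channels-last pipeline. A minimal usage sketch, assuming a TensorFlow build that includes this change:

```python
import numpy as np
import tensorflow as tf

# (batch, channels, rows, cols) input.
x = np.random.random((10, 3, 5, 5)).astype('float32')

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(data_format='channels_first', input_shape=(3, 5, 5)),
])
out = model.predict(x)

# Equivalent to transposing to channels_last first, then reshaping.
expected = np.reshape(np.transpose(x, (0, 2, 3, 1)), (10, 5 * 5 * 3))
np.testing.assert_allclose(out, expected, atol=1e-6)
print(out.shape)  # (10, 75)
```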
diff --git a/tensorflow/python/keras/_impl/keras/layers/core_test.py b/tensorflow/python/keras/_impl/keras/layers/core_test.py
index d22d8d12dc..9b360b65d6 100644
--- a/tensorflow/python/keras/_impl/keras/layers/core_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/core_test.py
@@ -124,6 +124,16 @@ class CoreLayersTest(test.TestCase):
testing_utils.layer_test(
keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4))
+ # Test channels_first
+ inputs = np.random.random((10, 3, 5, 5)).astype('float32')
+ outputs = testing_utils.layer_test(
+ keras.layers.Flatten,
+ kwargs={'data_format': 'channels_first'},
+ input_data=inputs)
+ target_outputs = np.reshape(
+ np.transpose(inputs, (0, 2, 3, 1)), (-1, 5 * 5 * 3))
+ self.assertAllClose(outputs, target_outputs)
+
@tf_test_util.run_in_graph_and_eager_modes()
def test_repeat_vector(self):
testing_utils.layer_test(
diff --git a/tensorflow/python/keras/_impl/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/cudnn_recurrent_test.py
index a06943b108..ad25eb226c 100644
--- a/tensorflow/python/keras/_impl/keras/layers/cudnn_recurrent_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/cudnn_recurrent_test.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import time
-
from absl.testing import parameterized
import numpy as np
@@ -33,43 +31,6 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer
class CuDNNTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes()
- def test_cudnn_rnn_timing(self):
- if test.is_gpu_available(cuda_only=True):
- with self.test_session(use_gpu=True):
- input_size = 10
- timesteps = 6
- units = 2
- num_samples = 32
-
- for rnn_type in ['lstm', 'gru']:
- times = []
- for use_cudnn in [True, False]:
- start_time = time.time()
- inputs = keras.layers.Input(shape=(None, input_size))
- if use_cudnn:
- if rnn_type == 'lstm':
- layer = keras.layers.CuDNNLSTM(units)
- else:
- layer = keras.layers.CuDNNGRU(units)
- else:
- if rnn_type == 'lstm':
- layer = keras.layers.LSTM(units)
- else:
- layer = keras.layers.GRU(units)
- outputs = layer(inputs)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- model = keras.models.Model(inputs, outputs)
- model.compile(optimizer, 'mse')
-
- x = np.random.random((num_samples, timesteps, input_size))
- y = np.random.random((num_samples, units))
- model.fit(x, y, epochs=4, batch_size=32)
-
- times.append(time.time() - start_time)
- self.assertGreater(times[1], times[0])
-
- @test_util.run_in_graph_and_eager_modes()
def test_cudnn_rnn_basics(self):
if test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/keras/_impl/keras/layers/normalization_test.py b/tensorflow/python/keras/_impl/keras/layers/normalization_test.py
index fa9277e3d1..84f0b2776c 100644
--- a/tensorflow/python/keras/_impl/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/normalization_test.py
@@ -168,6 +168,76 @@ class NormalizationLayersTest(test.TestCase):
new_model.compile('sgd', 'mse')
new_model.train_on_batch(x, x)
+ def test_that_trainable_disables_updates(self):
+ with self.test_session():
+ val_a = np.random.random((10, 4))
+ val_out = np.random.random((10, 4))
+
+ a = keras.layers.Input(shape=(4,))
+ layer = keras.layers.BatchNormalization(input_shape=(4,))
+ b = layer(a)
+ model = keras.models.Model(a, b)
+
+ model.trainable = False
+ assert not model.updates
+
+ model.compile('sgd', 'mse')
+ assert not model.updates
+
+ x1 = model.predict(val_a)
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ self.assertAllClose(x1, x2, atol=1e-7)
+
+ model.trainable = True
+ model.compile('sgd', 'mse')
+ assert model.updates
+
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ assert np.abs(np.sum(x1 - x2)) > 1e-5
+
+ layer.trainable = False
+ model.compile('sgd', 'mse')
+ assert not model.updates
+
+ x1 = model.predict(val_a)
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ self.assertAllClose(x1, x2, atol=1e-7)
+
+ def test_batchnorm_trainable(self):
+ """Tests that batchnorm layer is trainable when learning phase is enabled.
+
+ Computes mean and std for current inputs then
+ applies batch normalization using them.
+ """
+ with self.test_session():
+ bn_mean = 0.5
+ bn_std = 10.
+ val_a = np.expand_dims(np.arange(10.), axis=1)
+
+ def get_model(bn_mean, bn_std):
+ inp = keras.layers.Input(shape=(1,))
+ x = keras.layers.BatchNormalization()(inp)
+ model1 = keras.models.Model(inp, x)
+ model1.set_weights([
+ np.array([1.]),
+ np.array([0.]),
+ np.array([bn_mean]),
+ np.array([bn_std**2])
+ ])
+ return model1
+
+ # Simulates training-mode with trainable layer.
+ # Should use mini-batch statistics.
+ keras.backend.set_learning_phase(1)
+ model = get_model(bn_mean, bn_std)
+ model.compile(loss='mse', optimizer='rmsprop')
+ out = model.predict(val_a)
+ self.assertAllClose(
+ (val_a - np.mean(val_a)) / np.std(val_a), out, atol=1e-3)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index caf9e6f46f..93150b97fa 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -519,9 +519,10 @@ class RNN(Layer):
return [K.tile(initial_state, [1, self.cell.state_size])]
def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
- inputs, initial_state, constants = self._standardize_args(
- inputs, initial_state, constants)
-
+ inputs, initial_state, constants = _standardize_args(inputs,
+ initial_state,
+ constants,
+ self._num_constants)
if initial_state is None and constants is None:
return super(RNN, self).__call__(inputs, **kwargs)
@@ -661,46 +662,6 @@ class RNN(Layer):
else:
return output
- def _standardize_args(self, inputs, initial_state, constants):
- """Standardize `__call__` to a single list of tensor inputs.
-
- When running a model loaded from file, the input tensors
- `initial_state` and `constants` can be passed to `RNN.__call__` as part
- of `inputs` instead of by the dedicated keyword arguments. This method
- makes sure the arguments are separated and that `initial_state` and
- `constants` are lists of tensors (or None).
-
- Arguments:
- inputs: tensor or list/tuple of tensors
- initial_state: tensor or list of tensors or None
- constants: tensor or list of tensors or None
-
- Returns:
- inputs: tensor
- initial_state: list of tensors or None
- constants: list of tensors or None
- """
- if isinstance(inputs, list):
- assert initial_state is None and constants is None
- if self._num_constants is not None:
- constants = inputs[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
- inputs = inputs[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
- if len(inputs) > 1:
- initial_state = inputs[1:]
- inputs = inputs[0]
-
- def to_list_or_none(x):
- if x is None or isinstance(x, list):
- return x
- if isinstance(x, tuple):
- return list(x)
- return [x]
-
- initial_state = to_list_or_none(initial_state)
- constants = to_list_or_none(constants)
-
- return inputs, initial_state, constants
-
def reset_states(self, states=None):
if not self.stateful:
raise AttributeError('Layer must be stateful.')
@@ -914,13 +875,13 @@ class SimpleRNNCell(Layer):
prev_output = states[0]
if 0 < self.dropout < 1 and self._dropout_mask is None:
self._dropout_mask = _generate_dropout_mask(
- _generate_dropout_ones(inputs, array_ops.shape(inputs)[-1]),
+ array_ops.ones_like(inputs),
self.dropout,
training=training)
if (0 < self.recurrent_dropout < 1 and
self._recurrent_dropout_mask is None):
self._recurrent_dropout_mask = _generate_dropout_mask(
- _generate_dropout_ones(inputs, self.units),
+ array_ops.ones_like(prev_output),
self.recurrent_dropout,
training=training)
@@ -1333,14 +1294,14 @@ class GRUCell(Layer):
if 0 < self.dropout < 1 and self._dropout_mask is None:
self._dropout_mask = _generate_dropout_mask(
- _generate_dropout_ones(inputs, array_ops.shape(inputs)[-1]),
+ array_ops.ones_like(inputs),
self.dropout,
training=training,
count=3)
if (0 < self.recurrent_dropout < 1 and
self._recurrent_dropout_mask is None):
self._recurrent_dropout_mask = _generate_dropout_mask(
- _generate_dropout_ones(inputs, self.units),
+ array_ops.ones_like(h_tm1),
self.recurrent_dropout,
training=training,
count=3)
@@ -1873,14 +1834,14 @@ class LSTMCell(Layer):
def call(self, inputs, states, training=None):
if 0 < self.dropout < 1 and self._dropout_mask is None:
self._dropout_mask = _generate_dropout_mask(
- _generate_dropout_ones(inputs, array_ops.shape(inputs)[-1]),
+ array_ops.ones_like(inputs),
self.dropout,
training=training,
count=4)
if (0 < self.recurrent_dropout < 1 and
self._recurrent_dropout_mask is None):
self._recurrent_dropout_mask = _generate_dropout_mask(
- _generate_dropout_ones(inputs, self.units),
+ array_ops.ones_like(states[0]),
self.recurrent_dropout,
training=training,
count=4)
@@ -2254,12 +2215,7 @@ class LSTM(RNN):
return cls(**config)
-def _generate_dropout_ones(inputs, dims):
- return K.ones((array_ops.shape(inputs)[0], dims))
-
-
def _generate_dropout_mask(ones, rate, training=None, count=1):
-
def dropped_inputs():
return K.dropout(ones, rate)
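With `_generate_dropout_ones` removed, the recurrent cells now build their dropout masks from `ones_like` on the actual input or state tensor, so the mask inherits its shape and dtype directly. A minimal sketch of the mask mechanics, assuming `tf.keras` backend ops and eager execution:

```python
import numpy as np
import tensorflow as tf

K = tf.keras.backend

x = tf.constant(np.random.random((2, 3)).astype('float32'))

ones = tf.ones_like(x)             # same shape and dtype as the input tensor
mask = K.dropout(ones, level=0.5)  # scaled dropout mask built from the ones tensor

dropped = x * mask                 # how a cell applies the mask to its inputs
```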
@@ -2605,3 +2561,47 @@ class Recurrent(Layer):
}
base_config = super(Recurrent, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+
+
+def _standardize_args(inputs, initial_state, constants, num_constants):
+ """Standardizes `__call__` to a single list of tensor inputs.
+
+ When running a model loaded from a file, the input tensors
+ `initial_state` and `constants` can be passed to `RNN.__call__()` as part
+ of `inputs` instead of by the dedicated keyword arguments. This method
+ makes sure the arguments are separated and that `initial_state` and
+ `constants` are lists of tensors (or None).
+
+ Arguments:
+    inputs: Tensor or list/tuple of tensors, which may include constants
+      and initial states. In that case `num_constants` must be specified.
+    initial_state: Tensor or list of tensors or None, initial states.
+    constants: Tensor or list of tensors or None, constant tensors.
+    num_constants: Expected number of constants (if constants are passed as
+      part of the `inputs` list).
+
+ Returns:
+ inputs: Single tensor.
+ initial_state: List of tensors or None.
+ constants: List of tensors or None.
+ """
+ if isinstance(inputs, list):
+ assert initial_state is None and constants is None
+ if num_constants is not None:
+ constants = inputs[-num_constants:]
+ inputs = inputs[:-num_constants]
+ if len(inputs) > 1:
+ initial_state = inputs[1:]
+ inputs = inputs[0]
+
+ def to_list_or_none(x):
+ if x is None or isinstance(x, list):
+ return x
+ if isinstance(x, tuple):
+ return list(x)
+ return [x]
+
+ initial_state = to_list_or_none(initial_state)
+ constants = to_list_or_none(constants)
+
+ return inputs, initial_state, constants
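A small illustration of how the module-level `_standardize_args` splits a flat `inputs` list back into inputs, initial states, and constants when a model is reloaded from file; placeholder strings stand in for Keras tensors, and the tuple-to-list normalization is omitted for brevity:

```python
def standardize_args(inputs, initial_state, constants, num_constants):
  """Simplified restatement of the splitting logic above (sketch)."""
  if isinstance(inputs, list):
    assert initial_state is None and constants is None
    if num_constants is not None:
      constants = inputs[-num_constants:]
      inputs = inputs[:-num_constants]
    if len(inputs) > 1:
      initial_state = inputs[1:]
      inputs = inputs[0]
  return inputs, initial_state, constants


print(standardize_args(['x', 's1', 's2', 'c1'], None, None, num_constants=1))
# ('x', ['s1', 's2'], ['c1'])
```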
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
index 91b8c1148b..d1d09bb4a2 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras._impl.keras.utils import generic_utils
from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
@@ -284,6 +285,7 @@ class Bidirectional(Wrapper):
self.return_state = layer.return_state
self.supports_masking = True
self._trainable = True
+ self._num_constants = None
super(Bidirectional, self).__init__(layer, **kwargs)
self.input_spec = layer.input_spec
@@ -326,37 +328,51 @@ class Bidirectional(Wrapper):
return [output_shape] + state_shape + copy.copy(state_shape)
return output_shape
- def __call__(self, inputs, initial_state=None, **kwargs):
+ def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
+ """`Bidirectional.__call__` implements the same API as the wrapped `RNN`."""
+ inputs, initial_state, constants = _standardize_args(
+ inputs, initial_state, constants, self._num_constants)
+
if isinstance(inputs, list):
if len(inputs) > 1:
initial_state = inputs[1:]
inputs = inputs[0]
- if initial_state is None:
+ if initial_state is None and constants is None:
return super(Bidirectional, self).__call__(inputs, **kwargs)
- # Standardize `initial_state` into list
- if isinstance(initial_state, tuple):
- initial_state = list(initial_state)
- elif not isinstance(initial_state, list):
- initial_state = [initial_state]
-
- # Check if `initial_state` can be splitted into half
- num_states = len(initial_state)
- if num_states % 2 > 0:
- raise ValueError(
- 'When passing `initial_state` to a Bidirectional RNN, the state '
- 'should be a list containing the states of the underlying RNNs. '
- 'Found: ' + str(initial_state))
-
- # Applies the same workaround as in `RNN.__call__`, without handling
- # constants
- kwargs['initial_state'] = initial_state
- additional_inputs = initial_state
- additional_specs = [InputSpec(shape=K.int_shape(state))
- for state in initial_state]
- self.forward_layer.state_spec = additional_specs[:num_states // 2]
- self.backward_layer.state_spec = additional_specs[num_states // 2:]
+ # Applies the same workaround as in `RNN.__call__`
+ additional_inputs = []
+ additional_specs = []
+ if initial_state is not None:
+      # Check if `initial_state` can be split in half
+ num_states = len(initial_state)
+ if num_states % 2 > 0:
+ raise ValueError(
+ 'When passing `initial_state` to a Bidirectional RNN, '
+ 'the state should be a list containing the states of '
+ 'the underlying RNNs. '
+ 'Found: ' + str(initial_state))
+
+ kwargs['initial_state'] = initial_state
+ additional_inputs += initial_state
+ state_specs = [InputSpec(shape=K.int_shape(state))
+ for state in initial_state]
+ self.forward_layer.state_spec = state_specs[:num_states // 2]
+ self.backward_layer.state_spec = state_specs[num_states // 2:]
+ additional_specs += state_specs
+ if constants is not None:
+ kwargs['constants'] = constants
+ additional_inputs += constants
+ constants_spec = [InputSpec(shape=K.int_shape(constant))
+ for constant in constants]
+ self.forward_layer.constants_spec = constants_spec
+ self.backward_layer.constants_spec = constants_spec
+ additional_specs += constants_spec
+
+ self._num_constants = len(constants)
+ self.forward_layer._num_constants = self._num_constants
+ self.backward_layer._num_constants = self._num_constants
is_keras_tensor = K.is_keras_tensor(additional_inputs[0])
for tensor in additional_inputs:
@@ -381,12 +397,19 @@ class Bidirectional(Wrapper):
else:
return super(Bidirectional, self).__call__(inputs, **kwargs)
- def call(self, inputs, training=None, mask=None, initial_state=None):
+ def call(self, inputs,
+ training=None,
+ mask=None,
+ initial_state=None,
+ constants=None):
+ """`Bidirectional.call` implements the same API as the wrapped `RNN`."""
kwargs = {}
if generic_utils.has_arg(self.layer.call, 'training'):
kwargs['training'] = training
if generic_utils.has_arg(self.layer.call, 'mask'):
kwargs['mask'] = mask
+ if generic_utils.has_arg(self.layer.call, 'constants'):
+ kwargs['constants'] = constants
if initial_state is not None and generic_utils.has_arg(
self.layer.call, 'initial_state'):
@@ -444,13 +467,23 @@ class Bidirectional(Wrapper):
self.built = True
def compute_mask(self, inputs, mask):
+ if isinstance(mask, list):
+ mask = mask[0]
if self.return_sequences:
if not self.merge_mode:
- return [mask, mask]
+ output_mask = [mask, mask]
else:
- return mask
+ output_mask = mask
else:
- return None
+ output_mask = [None, None] if not self.merge_mode else None
+
+ if self.return_state:
+ states = self.forward_layer.states
+ state_mask = [None for _ in states]
+ if isinstance(output_mask, list):
+ return output_mask + state_mask * 2
+ return [output_mask] + state_mask * 2
+ return output_mask
@property
def trainable_weights(self):
@@ -488,5 +521,15 @@ class Bidirectional(Wrapper):
def get_config(self):
config = {'merge_mode': self.merge_mode}
+ if self._num_constants is not None:
+ config['num_constants'] = self._num_constants
base_config = super(Bidirectional, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ num_constants = config.pop('num_constants', None)
+ layer = super(Bidirectional, cls).from_config(config,
+ custom_objects=custom_objects)
+ layer._num_constants = num_constants
+ return layer
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
index 8fcf66e90f..05b272a470 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
+
import numpy as np
from tensorflow.python.framework import test_util as tf_test_util
@@ -26,6 +28,45 @@ from tensorflow.python.platform import test
from tensorflow.python.training.rmsprop import RMSPropOptimizer
+class _RNNCellWithConstants(keras.layers.Layer):
+
+ def __init__(self, units, **kwargs):
+ self.units = units
+ self.state_size = units
+ super(_RNNCellWithConstants, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ [input_shape, constant_shape] = input_shape
+
+ self.input_kernel = self.add_weight(
+ shape=(input_shape[-1], self.units),
+ initializer='uniform',
+ name='kernel')
+ self.recurrent_kernel = self.add_weight(
+ shape=(self.units, self.units),
+ initializer='uniform',
+ name='recurrent_kernel')
+ self.constant_kernel = self.add_weight(
+ shape=(constant_shape[-1], self.units),
+ initializer='uniform',
+ name='constant_kernel')
+ self.built = True
+
+ def call(self, inputs, states, constants):
+ [prev_output] = states
+ [constant] = constants
+ h_input = keras.backend.dot(inputs, self.input_kernel)
+ h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
+ h_const = keras.backend.dot(constant, self.constant_kernel)
+ output = h_input + h_state + h_const
+ return output, [output]
+
+ def get_config(self):
+ config = {'units': self.units}
+ base_config = super(_RNNCellWithConstants, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+
class TimeDistributedTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes()
@@ -383,6 +424,100 @@ class BidirectionalTest(test.TestCase):
layer.trainable = True
assert len(layer.trainable_weights) == 6
+ def test_Bidirectional_with_constants(self):
+ with self.test_session():
+ # Test basic case.
+ x = keras.Input((5, 5))
+ c = keras.Input((3,))
+ cell = _RNNCellWithConstants(32)
+ custom_objects = {'_RNNCellWithConstants': _RNNCellWithConstants}
+ with keras.utils.CustomObjectScope(custom_objects):
+ layer = keras.layers.Bidirectional(keras.layers.RNN(cell))
+ y = layer(x, constants=c)
+ model = keras.Model([x, c], y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ [np.zeros((6, 5, 5)), np.zeros((6, 3))],
+ np.zeros((6, 64))
+ )
+
+ # Test basic case serialization.
+ x_np = np.random.random((6, 5, 5))
+ c_np = np.random.random((6, 3))
+ y_np = model.predict([x_np, c_np])
+ weights = model.get_weights()
+ config = layer.get_config()
+
+ with keras.utils.CustomObjectScope(custom_objects):
+ layer = keras.layers.Bidirectional.from_config(copy.deepcopy(config))
+ y = layer(x, constants=c)
+ model = keras.Model([x, c], y)
+ model.set_weights(weights)
+ y_np_2 = model.predict([x_np, c_np])
+ self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+ # Test flat list inputs
+ with keras.utils.CustomObjectScope(custom_objects):
+ layer = keras.layers.Bidirectional.from_config(copy.deepcopy(config))
+ y = layer([x, c])
+ model = keras.Model([x, c], y)
+ model.set_weights(weights)
+ y_np_3 = model.predict([x_np, c_np])
+ self.assertAllClose(y_np, y_np_3, atol=1e-4)
+
+ def test_Bidirectional_with_constants_layer_passing_initial_state(self):
+ with self.test_session():
+ # Test basic case.
+ x = keras.Input((5, 5))
+ c = keras.Input((3,))
+ s_for = keras.Input((32,))
+ s_bac = keras.Input((32,))
+ cell = _RNNCellWithConstants(32)
+ custom_objects = {'_RNNCellWithConstants': _RNNCellWithConstants}
+ with keras.utils.CustomObjectScope(custom_objects):
+ layer = keras.layers.Bidirectional(keras.layers.RNN(cell))
+ y = layer(x, initial_state=[s_for, s_bac], constants=c)
+ model = keras.Model([x, s_for, s_bac, c], y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ [np.zeros((6, 5, 5)),
+ np.zeros((6, 32)),
+ np.zeros((6, 32)),
+ np.zeros((6, 3))],
+ np.zeros((6, 64))
+ )
+
+ # Test basic case serialization.
+ x_np = np.random.random((6, 5, 5))
+ s_fw_np = np.random.random((6, 32))
+ s_bk_np = np.random.random((6, 32))
+ c_np = np.random.random((6, 3))
+ y_np = model.predict([x_np, s_fw_np, s_bk_np, c_np])
+ weights = model.get_weights()
+ config = layer.get_config()
+
+ with keras.utils.CustomObjectScope(custom_objects):
+ layer = keras.layers.Bidirectional.from_config(copy.deepcopy(config))
+ y = layer(x, initial_state=[s_for, s_bac], constants=c)
+ model = keras.Model([x, s_for, s_bac, c], y)
+ model.set_weights(weights)
+ y_np_2 = model.predict([x_np, s_fw_np, s_bk_np, c_np])
+ self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+ # Verify that state is used
+ y_np_2_different_s = model.predict(
+ [x_np, s_fw_np + 10., s_bk_np + 10., c_np])
+ assert np.mean(y_np - y_np_2_different_s) != 0
+
+ # Test flat list inputs
+ with keras.utils.CustomObjectScope(custom_objects):
+ layer = keras.layers.Bidirectional.from_config(copy.deepcopy(config))
+ y = layer([x, s_for, s_bac, c])
+ model = keras.Model([x, s_for, s_bac, c], y)
+ model.set_weights(weights)
+ y_np_3 = model.predict([x_np, s_fw_np, s_bk_np, c_np])
+ self.assertAllClose(y_np, y_np_3, atol=1e-4)
+
def _to_list(ls):
if isinstance(ls, list):
diff --git a/tensorflow/python/keras/_impl/keras/metrics_test.py b/tensorflow/python/keras/_impl/keras/metrics_test.py
index 13cef97812..819bf60256 100644
--- a/tensorflow/python/keras/_impl/keras/metrics_test.py
+++ b/tensorflow/python/keras/_impl/keras/metrics_test.py
@@ -92,6 +92,7 @@ class KerasMetricsTest(test.TestCase):
def __init__(self, name='true_positives', **kwargs):
super(BinaryTruePositives, self).__init__(name=name, **kwargs)
self.true_positives = keras.backend.variable(value=0, dtype='int32')
+ self.stateful = True
def reset_states(self):
keras.backend.set_value(self.true_positives, 0)
@@ -132,10 +133,17 @@ class KerasMetricsTest(test.TestCase):
metrics=['acc', metric_fn])
# Test fit, evaluate
- samples = 1000
+ samples = 100
x = np.random.random((samples, 2))
y = np.random.randint(2, size=(samples, 1))
- model.fit(x, y, epochs=1, batch_size=10)
+ val_samples = 10
+ val_x = np.random.random((val_samples, 2))
+ val_y = np.random.randint(2, size=(val_samples, 1))
+
+ history = model.fit(x, y,
+ epochs=1,
+ batch_size=10,
+ validation_data=(val_x, val_y))
outs = model.evaluate(x, y, batch_size=10)
preds = model.predict(x)
@@ -145,6 +153,37 @@ class KerasMetricsTest(test.TestCase):
# Test correctness (e.g. updates should have been run)
self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5)
+ # Test correctness of the validation metric computation
+ val_preds = model.predict(val_x)
+ val_outs = model.evaluate(val_x, val_y, batch_size=10)
+ self.assertAllClose(
+ val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
+ self.assertAllClose(
+ val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
+
+ # Test with generators
+ gen = [(np.array([x0]), np.array([y0])) for x0, y0 in zip(x, y)]
+ val_gen = [(np.array([x0]), np.array([y0]))
+ for x0, y0 in zip(val_x, val_y)]
+ history = model.fit_generator(iter(gen),
+ epochs=1,
+ steps_per_epoch=samples,
+ validation_data=iter(val_gen),
+ validation_steps=val_samples)
+ outs = model.evaluate_generator(iter(gen), steps=samples)
+ preds = model.predict_generator(iter(gen), steps=samples)
+
+ # Test correctness of the metric results
+ self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5)
+
+ # Test correctness of the validation metric computation
+ val_preds = model.predict_generator(iter(val_gen), steps=val_samples)
+ val_outs = model.evaluate_generator(iter(val_gen), steps=val_samples)
+ self.assertAllClose(
+ val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
+ self.assertAllClose(
+ val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image.py b/tensorflow/python/keras/_impl/keras/preprocessing/image.py
index 6299445c34..5dfbf0fca5 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/image.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/image.py
@@ -217,6 +217,16 @@ def random_zoom(x,
@tf_export('keras.preprocessing.image.random_channel_shift')
def random_channel_shift(x, intensity, channel_axis=0):
+ """Perform a random channel shift.
+
+ Arguments:
+ x: Input tensor. Must be 3D.
+ intensity: Transformation intensity.
+ channel_axis: Index of axis for channels in the input tensor.
+
+ Returns:
+ Numpy image tensor.
+ """
x = np.rollaxis(x, channel_axis, 0)
min_x, max_x = np.min(x), np.max(x)
channel_images = [
@@ -451,54 +461,149 @@ def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'):
@tf_export('keras.preprocessing.image.ImageDataGenerator')
class ImageDataGenerator(object):
- """Generate minibatches of image data with real-time data augmentation.
+ """Generates batches of tensor image data with real-time data augmentation.
+ The data will be looped over (in batches).
Arguments:
- featurewise_center: set input mean to 0 over the dataset.
- samplewise_center: set each sample mean to 0.
- featurewise_std_normalization: divide inputs by std of the dataset.
- samplewise_std_normalization: divide each input by its std.
- zca_whitening: apply ZCA whitening.
+ featurewise_center: boolean, set input mean to 0 over the dataset,
+ feature-wise.
+ samplewise_center: boolean, set each sample mean to 0.
+ featurewise_std_normalization: boolean, divide inputs by std
+ of the dataset, feature-wise.
+ samplewise_std_normalization: boolean, divide each input by its std.
zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
- rotation_range: degrees (0 to 180).
- width_shift_range: fraction of total width, if < 1, or pixels if >= 1.
- height_shift_range: fraction of total height, if < 1, or pixels if >= 1.
- brightness_range: the range of brightness to apply
- shear_range: shear intensity (shear angle in degrees).
- zoom_range: amount of zoom. if scalar z, zoom will be randomly picked
- in the range [1-z, 1+z]. A sequence of two can be passed instead
- to select this range.
- channel_shift_range: shift range for each channel.
- fill_mode: points outside the boundaries are filled according to the
- given mode ('constant', 'nearest', 'reflect' or 'wrap'). Default
- is 'nearest'.
- Points outside the boundaries of the input are filled according to the
- given mode:
+ zca_whitening: boolean, apply ZCA whitening.
+ rotation_range: int, degree range for random rotations.
+ width_shift_range: float, 1-D array-like or int
+ float: fraction of total width, if < 1, or pixels if >= 1.
+ 1-D array-like: random elements from the array.
+ int: integer number of pixels from interval
+ `(-width_shift_range, +width_shift_range)`
+ With `width_shift_range=2` possible values are integers [-1, 0, +1],
+ same as with `width_shift_range=[-1, 0, +1]`,
+ while with `width_shift_range=1.0` possible values are floats in
+ the interval [-1.0, +1.0).
+      shear_range: float, shear intensity
+        (shear angle in counter-clockwise direction in degrees).
+ zoom_range: float or [lower, upper], Range for random zoom.
+ If a float, `[lower, upper] = [1-zoom_range, 1+zoom_range]`.
+ channel_shift_range: float, range for random channel shifts.
+ fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}.
+ Default is 'nearest'. Points outside the boundaries of the input
+ are filled according to the given mode:
'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
'nearest': aaaaaaaa|abcd|dddddddd
'reflect': abcddcba|abcd|dcbaabcd
'wrap': abcdabcd|abcd|abcdabcd
- cval: value used for points outside the boundaries when fill_mode is
- 'constant'. Default is 0.
- horizontal_flip: whether to randomly flip images horizontally.
- vertical_flip: whether to randomly flip images vertically.
- rescale: rescaling factor. If None or 0, no rescaling is applied,
- otherwise we multiply the data by the value provided. This is
- applied after the `preprocessing_function` (if any provided)
- but before any other transformation.
+ cval: float or int, value used for points outside the boundaries
+ when `fill_mode = "constant"`.
+ horizontal_flip: boolean, randomly flip inputs horizontally.
+ vertical_flip: boolean, randomly flip inputs vertically.
+ rescale: rescaling factor. Defaults to None. If None or 0, no rescaling
+ is applied, otherwise we multiply the data by the value provided
+ (before applying any other transformation).
preprocessing_function: function that will be applied on each input.
- The function will run before any other modification on it.
+ The function will run after the image is resized and augmented.
The function should take one argument:
one image (Numpy tensor with rank 3),
and should output a Numpy tensor with the same shape.
- data_format: 'channels_first' or 'channels_last'. In 'channels_first'
- mode, the channels dimension
- (the depth) is at index 1, in 'channels_last' mode it is at index 3.
+ data_format: One of {"channels_first", "channels_last"}.
+ "channels_last" mode means that the images should have shape
+ `(samples, height, width, channels)`,
+ "channels_first" mode means that the images should have shape
+ `(samples, channels, height, width)`.
It defaults to the `image_data_format` value found in your
- Keras config file at `~/.keras/keras.json`.
+ Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
- validation_split: fraction of images reserved for validation (strictly
- between 0 and 1).
+ validation_split: float, fraction of images reserved for validation
+ (strictly between 0 and 1).
+
+ Examples:
+ Example of using `.flow(x, y)`:
+ ```python
+ (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+ y_train = np_utils.to_categorical(y_train, num_classes)
+ y_test = np_utils.to_categorical(y_test, num_classes)
+ datagen = ImageDataGenerator(
+ featurewise_center=True,
+ featurewise_std_normalization=True,
+ rotation_range=20,
+ width_shift_range=0.2,
+ height_shift_range=0.2,
+ horizontal_flip=True)
+ # compute quantities required for featurewise normalization
+ # (std, mean, and principal components if ZCA whitening is applied)
+ datagen.fit(x_train)
+ # fits the model on batches with real-time data augmentation:
+ model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
+ steps_per_epoch=len(x_train) / 32, epochs=epochs)
+ # here's a more "manual" example
+ for e in range(epochs):
+ print('Epoch', e)
+ batches = 0
+ for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
+ model.fit(x_batch, y_batch)
+ batches += 1
+ if batches >= len(x_train) / 32:
+ # we need to break the loop by hand because
+ # the generator loops indefinitely
+ break
+ ```
+ Example of using `.flow_from_directory(directory)`:
+ ```python
+ train_datagen = ImageDataGenerator(
+ rescale=1./255,
+ shear_range=0.2,
+ zoom_range=0.2,
+ horizontal_flip=True)
+ test_datagen = ImageDataGenerator(rescale=1./255)
+ train_generator = train_datagen.flow_from_directory(
+ 'data/train',
+ target_size=(150, 150),
+ batch_size=32,
+ class_mode='binary')
+ validation_generator = test_datagen.flow_from_directory(
+ 'data/validation',
+ target_size=(150, 150),
+ batch_size=32,
+ class_mode='binary')
+ model.fit_generator(
+ train_generator,
+ steps_per_epoch=2000,
+ epochs=50,
+ validation_data=validation_generator,
+ validation_steps=800)
+ ```
+ Example of transforming images and masks together.
+ ```python
+ # we create two instances with the same arguments
+ data_gen_args = dict(featurewise_center=True,
+ featurewise_std_normalization=True,
+ rotation_range=90.,
+ width_shift_range=0.1,
+ height_shift_range=0.1,
+ zoom_range=0.2)
+ image_datagen = ImageDataGenerator(**data_gen_args)
+ mask_datagen = ImageDataGenerator(**data_gen_args)
+ # Provide the same seed and keyword arguments to the fit and flow methods
+ seed = 1
+ image_datagen.fit(images, augment=True, seed=seed)
+ mask_datagen.fit(masks, augment=True, seed=seed)
+ image_generator = image_datagen.flow_from_directory(
+ 'data/images',
+ class_mode=None,
+ seed=seed)
+ mask_generator = mask_datagen.flow_from_directory(
+ 'data/masks',
+ class_mode=None,
+ seed=seed)
+ # combine generators into one which yields image and masks
+ train_generator = zip(image_generator, mask_generator)
+ model.fit_generator(
+ train_generator,
+ steps_per_epoch=2000,
+ epochs=50)
+ ```
"""
def __init__(self,
@@ -613,6 +718,31 @@ class ImageDataGenerator(object):
save_prefix='',
save_format='png',
subset=None):
+    """Generates batches of augmented/normalized data from given numpy arrays.
+
+ Arguments:
+ x: data. Should have rank 4.
+ In case of grayscale data, the channels axis should have value 1
+ and in case of RGB data, it should have value 3.
+ y: labels.
+ batch_size: int (default: 32).
+ shuffle: boolean (default: True).
+ seed: int (default: None).
+ save_to_dir: None or str (default: None).
+ This allows you to optionally specify a directory
+ to which to save the augmented pictures being generated
+ (useful for visualizing what you are doing).
+ save_prefix: str (default: `''`). Prefix to use for filenames of
+ saved pictures (only relevant if `save_to_dir` is set).
+ save_format: one of "png", "jpeg". Default: "png".
+ (only relevant if `save_to_dir` is set)
+ subset: Subset of data (`"training"` or `"validation"`) if
+ `validation_split` is set in `ImageDataGenerator`.
+
+ Returns:
+ An Iterator yielding tuples of `(x, y)` where `x` is a numpy array of
+ image data and `y` is a numpy array of corresponding labels.
+ """
return NumpyArrayIterator(
x,
y,
@@ -641,6 +771,65 @@ class ImageDataGenerator(object):
follow_links=False,
subset=None,
interpolation='nearest'):
+    """Generates batches of augmented/normalized data given a directory path.
+
+ Arguments:
+ directory: path to the target directory. It should contain one
+ subdirectory per class. Any PNG, JPG, BMP, PPM or TIF images
+ inside each of the subdirectories directory tree will be included
+ in the generator. See [this script]
+ (https://gist.github.com/fchollet/0830affa1f7f19fd47b06d4cf89ed44d)
+ for more details.
+ target_size: tuple of integers `(height, width)`, default: `(256,
+ 256)`. The dimensions to which all images found will be resized.
+        color_mode: one of "grayscale", "rgb". Default: "rgb". Whether the
+ images will be converted to have 1 or 3 color channels.
+ classes: optional list of class subdirectories (e.g. `['dogs',
+ 'cats']`). Default: None. If not provided, the list of classes
+ will be automatically inferred from the subdirectory
+ names/structure under `directory`, where each subdirectory will be
+ treated as a different class (and the order of the classes, which
+ will map to the label indices, will be alphanumeric). The
+ dictionary containing the mapping from class names to class
+ indices can be obtained via the attribute `class_indices`.
+ class_mode: one of "categorical", "binary", "sparse", "input" or
+ None. Default: "categorical". Determines the type of label arrays
+ that are returned: "categorical" will be 2D one-hot encoded
+ labels, "binary" will be 1D binary labels, "sparse" will be 1D
+ integer labels, "input" will be images identical to input images
+ (mainly used to work with autoencoders). If None, no labels are
+ returned (the generator will only yield batches of image data,
+ which is useful to use `model.predict_generator()`,
+ `model.evaluate_generator()`, etc.). Please note that in case of
+ class_mode None, the data still needs to reside in a subdirectory
+ of `directory` for it to work correctly.
+ batch_size: size of the batches of data (default: 32).
+ shuffle: whether to shuffle the data (default: True)
+ seed: optional random seed for shuffling and transformations.
+ save_to_dir: None or str (default: None). This allows you to
+ optionally specify a directory to which to save the augmented
+ pictures being generated (useful for visualizing what you are doing)
+ save_prefix: str. Prefix to use for filenames of saved pictures
+ (only relevant if `save_to_dir` is set).
+ save_format: one of "png", "jpeg" (only relevant if `save_to_dir` is
+ set). Default: "png".
+ follow_links: whether to follow symlinks inside class subdirectories
+ (default: False).
+ subset: Subset of data (`"training"` or `"validation"`) if
+          `validation_split` is set in `ImageDataGenerator`.
+ interpolation: Interpolation method used to resample the image if
+ the target size is different from that of the loaded image.
+ Supported methods are `"nearest"`, `"bilinear"`, and `"bicubic"`.
+ If PIL version 1.1.3 or newer is installed, `"lanczos"` is also
+ supported. If PIL version 3.4.0 or newer is installed, `"box"` and
+ `"hamming"` are also supported. By default, `"nearest"` is used.
+
+ Returns:
+ A DirectoryIterator yielding tuples of `(x, y)` where `x` is a
+ numpy array containing a batch of images with shape
+ `(batch_size, *target_size, channels)` and `y` is a numpy
+ array of corresponding labels.
+ """
return DirectoryIterator(
directory,
self,
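A minimal usage sketch of the `flow_from_directory()` arguments documented above, assuming the one-subdirectory-per-class layout the docstring describes; the `data/train` path and the class names are hypothetical.

```python
from tensorflow.python.keras._impl import keras

# Hypothetical layout: data/train/cats/*.jpg, data/train/dogs/*.jpg
datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
    validation_split=0.2)

train_gen = datagen.flow_from_directory(
    'data/train',                 # one subdirectory per class
    target_size=(256, 256),       # (height, width) every image is resized to
    color_mode='rgb',
    class_mode='categorical',     # 2D one-hot labels
    batch_size=32,
    subset='training')            # pairs with validation_split above

x_batch, y_batch = next(train_gen)   # x: (32, 256, 256, 3), y: (32, num_classes)
print(train_gen.class_indices)       # e.g. {'cats': 0, 'dogs': 1}
```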
@@ -669,7 +858,7 @@ class ImageDataGenerator(object):
The inputs, normalized.
"""
if self.preprocessing_function:
- x = self.image_data_generator.preprocessing_function(x)
+ x = self.preprocessing_function(x)
if self.rescale:
x *= self.rescale
if self.samplewise_center:
@@ -737,15 +926,24 @@ class ImageDataGenerator(object):
theta = 0
if self.height_shift_range:
- tx = np.random.uniform(-self.height_shift_range, self.height_shift_range)
- if self.height_shift_range < 1:
+ try: # 1-D array-like or int
+ tx = np.random.choice(self.height_shift_range)
+ tx *= np.random.choice([-1, 1])
+ except ValueError: # floating point
+ tx = np.random.uniform(-self.height_shift_range,
+ self.height_shift_range)
+ if np.max(self.height_shift_range) < 1:
tx *= x.shape[img_row_axis]
else:
tx = 0
if self.width_shift_range:
- ty = np.random.uniform(-self.width_shift_range, self.width_shift_range)
- if self.width_shift_range < 1:
+ try: # 1-D array-like or int
+ ty = np.random.choice(self.width_shift_range)
+ ty *= np.random.choice([-1, 1])
+ except ValueError: # floating point
+ ty = np.random.uniform(-self.width_shift_range, self.width_shift_range)
+ if np.max(self.width_shift_range) < 1:
ty *= x.shape[img_col_axis]
else:
ty = 0
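The hunk above lets `height_shift_range`/`width_shift_range` also be an int or a 1-D array of candidate shifts, while keeping the old uniform-float behaviour as the fallback. A standalone NumPy sketch of that sampling branch (the `sample_shift` helper is illustrative, not part of the class):

```python
import numpy as np

def sample_shift(shift_range, axis_size):
    """Mirrors the new branch: int / 1-D array -> random choice with a
    random sign; float scalar -> uniform in [-shift_range, shift_range]."""
    try:                                    # 1-D array-like or int
        t = np.random.choice(shift_range) * np.random.choice([-1, 1])
    except ValueError:                      # floating point scalar
        t = np.random.uniform(-shift_range, shift_range)
    if np.max(shift_range) < 1:             # fractions of the image dimension
        t *= axis_size
    return t

sample_shift(0.2, 100)         # uniform shift within +/-20% of 100 rows
sample_shift(5, 100)           # integer pixel shift drawn from range(5), signed
sample_shift([0, 8, 16], 100)  # one of the listed pixel shifts, signed
```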
@@ -809,24 +1007,25 @@ class ImageDataGenerator(object):
return x
def fit(self, x, augment=False, rounds=1, seed=None):
- """Fits internal statistics to some sample data.
+ """Computes the internal data statistics based on an array of sample data.
- Required for featurewise_center, featurewise_std_normalization
- and zca_whitening.
+ These are statistics related to the data-dependent transformations.
+ Only required if `featurewise_center`, `featurewise_std_normalization`
+ or `zca_whitening` are enabled.
Arguments:
- x: Numpy array, the data to fit on. Should have rank 4.
- In case of grayscale data,
- the channels axis should have value 1, and in case
- of RGB data, it should have value 3.
- augment: Whether to fit on randomly augmented samples
- rounds: If `augment`,
- how many augmentation passes to do over the data
- seed: random seed.
+ x: sample data. Should have rank 4.
+ In case of grayscale data, the channels axis should have value 1
+ and in case of RGB data, it should have value 3.
+ augment: Boolean (default: False). Whether to fit on randomly
+ augmented samples.
+ rounds: int (default: 1). If augment, how many augmentation passes
+ over the data to use.
+ seed: int (default: None). Random seed.
Raises:
- ValueError: in case of invalid input `x`.
- ImportError: if Scipy is not available.
+ ValueError: If input rank is not 4.
+ ImportError: If SciPy is not available.
"""
x = np.asarray(x, dtype=K.floatx())
if x.ndim != 4:
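A short sketch of the `fit()` call the revised docstring describes, assuming featurewise statistics are wanted; `x_train` here is a hypothetical rank-4 float array.

```python
import numpy as np
from tensorflow.python.keras._impl import keras

datagen = keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True)

# Rank-4 sample data: (samples, height, width, channels); channels=3 for RGB.
x_train = np.random.random((100, 32, 32, 3))
datagen.fit(x_train, augment=False, rounds=1, seed=42)

# The fitted mean/std are applied whenever batches are standardized.
flow = datagen.flow(x_train, np.zeros(len(x_train)), batch_size=10)
x_batch, y_batch = next(flow)
```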
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py
index 001fee91f9..d2e8ac10ae 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py
@@ -246,7 +246,37 @@ class TestImage(test.TestCase):
self.assertEqual(len(dir_iterator.class_indices), num_classes)
self.assertEqual(len(dir_iterator.classes), count)
self.assertEqual(set(dir_iterator.filenames), set(filenames))
- _ = dir_iterator.next()
+
+ def preprocessing_function(x):
+ """This will fail if not provided by a Numpy array.
+
+ Note: This is made to enforce backward compatibility.
+
+ Args:
+ x: A numpy array.
+
+ Returns:
+ An array of zeros with the same shape as the given array.
+ """
+ self.assertEqual(x.shape, (26, 26, 3))
+ self.assertIs(type(x), np.ndarray)
+ return np.zeros_like(x)
+
+ # Test usage as Sequence
+ generator = keras.preprocessing.image.ImageDataGenerator(
+ preprocessing_function=preprocessing_function)
+ dir_seq = generator.flow_from_directory(
+ str(temp_dir),
+ target_size=(26, 26),
+ color_mode='rgb',
+ batch_size=3,
+ class_mode='categorical')
+ self.assertEqual(len(dir_seq), count // 3 + 1)
+ x1, y1 = dir_seq[1]
+ self.assertEqual(x1.shape, (3, 26, 26, 3))
+ self.assertEqual(y1.shape, (3, num_classes))
+ x1, y1 = dir_seq[5]
+ self.assertTrue((x1 == 0).all())
def directory_iterator_with_validation_split_test_helper(
self, validation_split):
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
index e68c171d9c..49bb0b957a 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
@@ -357,9 +357,15 @@ class TimeseriesGenerator(Sequence):
self.reverse = reverse
self.batch_size = batch_size
+ if self.start_index > self.end_index:
+ raise ValueError('`start_index+length=%i > end_index=%i` '
+ 'is disallowed, as no part of the sequence '
+ 'would be left to be used as current step.' %
+ (self.start_index, self.end_index))
+
def __len__(self):
length = int(
- np.ceil((self.end_index - self.start_index) /
+ np.ceil((self.end_index - self.start_index + 1) /
(self.batch_size * self.stride)))
return length if length >= 0 else 0
@@ -373,11 +379,12 @@ class TimeseriesGenerator(Sequence):
def __getitem__(self, index):
if self.shuffle:
rows = np.random.randint(
- self.start_index, self.end_index, size=self.batch_size)
+ self.start_index, self.end_index + 1, size=self.batch_size)
else:
i = self.start_index + self.batch_size * self.stride * index
- rows = np.arange(i, min(i + self.batch_size * self.stride,
- self.end_index), self.stride)
+ rows = np.arange(
+ i, min(i + self.batch_size * self.stride, self.end_index + 1),
+ self.stride)
samples, targets = self._empty_batch(len(rows))
for j in range(len(rows)):
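A numeric sketch of the off-by-one fix above: `end_index` is now treated as inclusive, so with 10 samples and `length=3` the last sample (index 9) is usable as a target and a batch size of 1 yields 7 batches rather than 6. This mirrors the new tests; nothing here goes beyond the public `TimeseriesGenerator` API.

```python
import numpy as np
from tensorflow.python.keras._impl import keras

data = np.arange(10).reshape(-1, 1)
gen = keras.preprocessing.sequence.TimeseriesGenerator(
    data, data, length=3, batch_size=1)

len(gen)                          # 7 == len(data) - length
x_last, y_last = gen[len(gen) - 1]
y_last                            # array([[9]]) -- the final sample is kept

# A length longer than the data now fails fast instead of silently yielding
# an empty generator:
#   keras.preprocessing.sequence.TimeseriesGenerator(data, data, length=50)
#   -> ValueError: `start_index+length=50 > end_index=9` is disallowed ...
```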
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py
index b9bfdd0004..0e7045f517 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from math import ceil
+
import numpy as np
from tensorflow.python.keras._impl import keras
@@ -146,7 +148,7 @@ class TestSequence(test.TestCase):
start_index=10,
end_index=30,
batch_size=2)
- self.assertEqual(len(data_gen), 5)
+ self.assertEqual(len(data_gen), 6)
self.assertAllClose(data_gen[0][0],
np.array([[[10], [12], [14], [16], [18]],
[[11], [13], [15], [17], [19]]]))
@@ -163,13 +165,74 @@ class TestSequence(test.TestCase):
end_index=30,
batch_size=2)
- self.assertEqual(len(data_gen), 5)
+ self.assertEqual(len(data_gen), 6)
self.assertAllClose(data_gen[0][0],
np.array(
[np.array(data[10:19:2]),
np.array(data[11:20:2])]))
self.assertAllClose(data_gen[0][1], np.array([targets[20], targets[21]]))
+ with self.assertRaises(ValueError) as context:
+ keras.preprocessing.sequence.TimeseriesGenerator(data, targets, length=50)
+ error = str(context.exception)
+ self.assertIn('`start_index+length=50 > end_index=49` is disallowed', error)
+
+ def test_TimeSeriesGenerator_doesnt_miss_any_sample(self):
+ x = np.array([[i] for i in range(10)])
+
+ for length in range(3, 10):
+ g = keras.preprocessing.sequence.TimeseriesGenerator(
+ x, x, length=length, batch_size=1)
+ expected = max(0, len(x) - length)
+ actual = len(g)
+ self.assertEqual(expected, actual)
+
+ if actual > 0:
+ # All elements in range(length, 10) should be used as current step
+ expected = np.arange(length, 10).reshape(-1, 1)
+
+ y = np.concatenate([g[ix][1] for ix in range(len(g))], axis=0)
+ self.assertAllClose(y, expected)
+
+ x = np.array([[i] for i in range(23)])
+
+ strides = (1, 1, 5, 7, 3, 5, 3)
+ lengths = (3, 3, 4, 3, 1, 3, 7)
+ batch_sizes = (6, 6, 6, 5, 6, 6, 6)
+ shuffles = (False, True, True, False, False, False, False)
+
+ for stride, length, batch_size, shuffle in zip(strides, lengths,
+ batch_sizes, shuffles):
+ g = keras.preprocessing.sequence.TimeseriesGenerator(
+ x,
+ x,
+ length=length,
+ sampling_rate=1,
+ stride=stride,
+ start_index=0,
+ end_index=None,
+ shuffle=shuffle,
+ reverse=False,
+ batch_size=batch_size)
+ if shuffle:
+ # all batches have the same size when shuffle is True.
+ expected_sequences = ceil(
+ (23 - length) / float(batch_size * stride)) * batch_size
+ else:
+ # last batch will be different if `(samples - length) / stride`
+ # is not a multiple of `batch_size`.
+ expected_sequences = ceil((23 - length) / float(stride))
+
+ expected_batches = ceil(expected_sequences / float(batch_size))
+
+ y = [g[ix][1] for ix in range(len(g))]
+
+ actual_sequences = sum(len(iy) for iy in y)
+ actual_batches = len(y)
+
+ self.assertEqual(expected_sequences, actual_sequences)
+ self.assertEqual(expected_batches, actual_batches)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text.py b/tensorflow/python/keras/_impl/keras/preprocessing/text.py
index f652f318f3..f3b57de257 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/text.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/text.py
@@ -42,13 +42,15 @@ def text_to_word_sequence(text,
filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
lower=True,
split=' '):
- """Converts a text to a sequence of words (or tokens).
+ r"""Converts a text to a sequence of words (or tokens).
Arguments:
text: Input text (string).
- filters: Sequence of characters to filter out.
- lower: Whether to convert the input to lowercase.
- split: Sentence split marker (string).
+ filters: list (or concatenation) of characters to filter out, such as
+ punctuation. Default: '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
+ includes basic punctuation, tabs, and newlines.
+ lower: boolean, whether to convert the input to lowercase.
+ split: string, separator for word splitting.
Returns:
A list of words (or tokens).
@@ -56,12 +58,21 @@ def text_to_word_sequence(text,
if lower:
text = text.lower()
- if sys.version_info < (3,) and isinstance(text, unicode):
- translate_map = dict((ord(c), unicode(split)) for c in filters)
+ if sys.version_info < (3,):
+ if isinstance(text, unicode):
+ translate_map = dict((ord(c), unicode(split)) for c in filters)
+ text = text.translate(translate_map)
+ elif len(split) == 1:
+ translate_map = maketrans(filters, split * len(filters))
+ text = text.translate(translate_map)
+ else:
+ for c in filters:
+ text = text.replace(c, split)
else:
- translate_map = maketrans(filters, split * len(filters))
+ translate_dict = dict((c, split) for c in filters)
+ translate_map = maketrans(translate_dict)
+ text = text.translate(translate_map)
- text = text.translate(translate_map)
seq = text.split(split)
return [i for i in seq if i]
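The rewritten Python 2 branch above adds support for multi-character `split` strings (the single-character `maketrans` path cannot handle them); on Python 3 the dict-based `maketrans` covers both cases. A usage sketch mirroring the new tests further down:

```python
from tensorflow.python.keras._impl import keras

text_to_word_sequence = keras.preprocessing.text.text_to_word_sequence

text_to_word_sequence('Hello, world!')                    # ['hello', 'world']
text_to_word_sequence('hello!stop?world!', split='stop')  # ['hello', 'world']
text_to_word_sequence(u'ali! veli? kırk dokuz elli')      # [u'ali', u'veli', ...]
```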
@@ -72,20 +83,23 @@ def one_hot(text,
filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
lower=True,
split=' '):
- """One-hot encodes a text into a list of word indexes of size n.
+ r"""One-hot encodes a text into a list of word indexes of size n.
This is a wrapper to the `hashing_trick` function using `hash` as the
hashing function; unicity of word to index mapping non-guaranteed.
Arguments:
text: Input text (string).
- n: Dimension of the hashing space.
- filters: Sequence of characters to filter out.
- lower: Whether to convert the input to lowercase.
- split: Sentence split marker (string).
+ n: int, size of vocabulary.
+ filters: list (or concatenation) of characters to filter out, such as
+ punctuation. Default: '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
+ includes basic punctuation, tabs, and newlines.
+ lower: boolean, whether to set the text to lowercase.
+ split: string, separator for word splitting.
Returns:
- A list of integer word indices (unicity non-guaranteed).
+ List of integers in [1, n].
+ Each integer encodes a word (unicity non-guaranteed).
"""
return hashing_trick(
text, n, hash_function=hash, filters=filters, lower=lower, split=split)
@@ -98,19 +112,21 @@ def hashing_trick(text,
filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
lower=True,
split=' '):
- """Converts a text to a sequence of indexes in a fixed-size hashing space.
+ r"""Converts a text to a sequence of indexes in a fixed-size hashing space.
Arguments:
text: Input text (string).
n: Dimension of the hashing space.
- hash_function: if `None` uses python `hash` function, can be 'md5' or
+ hash_function: defaults to python `hash` function, can be 'md5' or
any function that takes a string as input and returns an int.
- Note that `hash` is not a stable hashing function, so
+ Note that 'hash' is not a stable hashing function, so
it is not consistent across different runs, while 'md5'
is a stable hashing function.
- filters: Sequence of characters to filter out.
- lower: Whether to convert the input to lowercase.
- split: Sentence split marker (string).
+ filters: list (or concatenation) of characters to filter out, such as
+ punctuation. Default: '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
+ includes basic punctuation, tabs, and newlines.
+ lower: boolean, whether to set the text to lowercase.
+ split: string, separator for word splitting.
Returns:
A list of integer word indices (unicity non-guaranteed).
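A small sketch of `hashing_trick` with the stable `'md5'` hasher mentioned above; the exact indices depend on the vocabulary size `n`, so the values shown in the comment are only placeholders.

```python
from tensorflow.python.keras._impl import keras

hashing_trick = keras.preprocessing.text.hashing_trick

# Indices lie in [1, n]; collisions are possible (unicity non-guaranteed).
seq = hashing_trick('The quick brown fox jumped over the lazy dog',
                    n=100, hash_function='md5')
# e.g. [31, 64, 11, ...]; 'the' maps to the same index on both occurrences.
```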
@@ -150,7 +166,7 @@ class Tokenizer(object):
filtered from the texts. The default is all punctuation, plus
tabs and line breaks, minus the `'` character.
lower: boolean. Whether to convert the texts to lowercase.
- split: character or string to use for token splitting.
+ split: string, separator for word splitting.
char_level: if True, every character will be treated as a token.
oov_token: if given, it will be added to word_index and used to
replace out-of-vocabulary words during text_to_sequence calls
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
index c6a267e57e..6cdc0a70cc 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
@@ -114,11 +114,21 @@ class TestText(test.TestCase):
seq = keras.preprocessing.text.text_to_word_sequence(text)
self.assertEqual(seq, ['hello', 'world'])
+ def test_text_to_word_sequence_multichar_split(self):
+ text = 'hello!stop?world!'
+ seq = keras.preprocessing.text.text_to_word_sequence(text, split='stop')
+ self.assertEqual(seq, ['hello', 'world'])
+
def test_text_to_word_sequence_unicode(self):
text = u'ali! veli? kırk dokuz elli'
seq = keras.preprocessing.text.text_to_word_sequence(text)
self.assertEqual(seq, [u'ali', u'veli', u'kırk', u'dokuz', u'elli'])
+ def test_text_to_word_sequence_unicode_multichar_split(self):
+ text = u'ali!stopveli?stopkırkstopdokuzstopelli'
+ seq = keras.preprocessing.text.text_to_word_sequence(text, split='stop')
+ self.assertEqual(seq, [u'ali', u'veli', u'kırk', u'dokuz', u'elli'])
+
def test_tokenizer_unicode(self):
texts = [
u'ali veli kırk dokuz elli', u'ali veli kırk dokuz elli veli kırk dokuz'
diff --git a/tensorflow/python/keras/_impl/keras/testing_utils.py b/tensorflow/python/keras/_impl/keras/testing_utils.py
index 60799ee1e0..b8172064c3 100644
--- a/tensorflow/python/keras/_impl/keras/testing_utils.py
+++ b/tensorflow/python/keras/_impl/keras/testing_utils.py
@@ -37,7 +37,6 @@ def get_test_data(train_samples,
test_samples: Integer, how many test samples to generate.
input_shape: Tuple of integers, shape of the inputs.
num_classes: Integer, number of classes for the data and targets.
- Only relevant if `classification=True`.
Returns:
A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -45,7 +44,7 @@ def get_test_data(train_samples,
num_sample = train_samples + test_samples
templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
y = np.random.randint(0, num_classes, size=(num_sample,))
- x = np.zeros((num_sample,) + input_shape)
+ x = np.zeros((num_sample,) + input_shape, dtype=np.float32)
for i in range(num_sample):
x[i] = templates[y[i]] + np.random.normal(loc=0, scale=1., size=input_shape)
return ((x[:train_samples], y[:train_samples]),
diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
index db184d278c..a69893955f 100644
--- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
@@ -349,7 +349,10 @@ class Progbar(object):
self._values[k][0] += v * (current - self._seen_so_far)
self._values[k][1] += (current - self._seen_so_far)
else:
- self._values[k] = v
+ # Stateful metrics output a numeric value. This representation
+ # means "take an average from a single value" but keeps the
+ # numeric formatting.
+ self._values[k] = [v, 1]
self._seen_so_far = current
now = time.time()
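A plain-Python sketch of why the `[v, 1]` pair above works: averaged metrics are accumulated as a `[weighted_sum, weight]` pair and displayed as their ratio, so storing a stateful metric as `[v, 1]` routes it through the same numeric formatting. The dictionary and format string below are illustrative, not the actual `Progbar` internals.

```python
values = {'loss': [625.0, 50],   # weighted sum over 50 seen samples
          'acc': [0.91, 1]}      # stateful metric stored as [v, 1]

for name, (total, weight) in sorted(values.items()):
    print(' %s: %.4f' % (name, total / max(1, weight)))
# acc: 0.9100
# loss: 12.5000
```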
diff --git a/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils.py
index 231ace2a0b..48c2537727 100644
--- a/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/multi_gpu_utils.py
@@ -34,7 +34,7 @@ def _normalize_device_name(name):
@tf_export('keras.utils.multi_gpu_model')
-def multi_gpu_model(model, gpus):
+def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
"""Replicates a model on different GPUs.
Specifically, this function implements single-machine
@@ -61,12 +61,18 @@ def multi_gpu_model(model, gpus):
(see usage example below).
gpus: Integer >= 2, number of GPUs on which to create
model replicas.
+ cpu_merge: A boolean value indicating whether to force
+ merging model weights under the scope of the CPU.
+ cpu_relocation: A boolean value indicating whether to
+ create the model's weights under the scope of the CPU.
+ If the model is not defined under any preceding device
+ scope, you can still rescue it by activating this option.
Returns:
A Keras `Model` instance which can be used just like the initial
`model` argument, but which distributes its workload on multiple GPUs.
- Example:
+ Example 1: Training models with weights merge on CPU
```python
import tensorflow as tf
@@ -107,6 +113,39 @@ def multi_gpu_model(model, gpus):
model.save('my_model.h5')
```
+ Example 2: Training models with weights merge on CPU using cpu_relocation
+
+ ```python
+ ..
+ # Not needed to change the device scope for model definition:
+ model = Xception(weights=None, ..)
+
+ try:
+ model = multi_gpu_model(model, cpu_relocation=True)
+ print("Training using multiple GPUs..")
+ except:
+ print("Training using single GPU or CPU..")
+
+ model.compile(..)
+ ..
+ ```
+
+ Example 3: Training models with weights merge on GPU (recommended for NVLink)
+
+ ```python
+ ..
+ # Not needed to change the device scope for model definition:
+ model = Xception(weights=None, ..)
+
+ try:
+ model = multi_gpu_model(model, cpu_merge=False)
+ print("Training using multiple GPUs..")
+ except:
+ print("Training using single GPU or CPU..")
+ model.compile(..)
+ ..
+ ```
+
Raises:
ValueError: if the `gpus` argument does not match available devices.
"""
@@ -166,6 +205,12 @@ def multi_gpu_model(model, gpus):
start = stride * i
return array_ops.slice(data, start, size)
+ # Relocate the model definition under CPU device scope if needed
+ if cpu_relocation:
+ from tensorflow.python.keras._impl.keras.models import clone_model # pylint: disable=g-import-not-at-top
+ with ops.device('/cpu:0'):
+ model = clone_model(model)
+
all_outputs = []
for i in range(len(model.outputs)):
all_outputs.append([])
@@ -199,8 +244,8 @@ def multi_gpu_model(model, gpus):
for o in range(len(outputs)):
all_outputs[o].append(outputs[o])
- # Merge outputs on CPU.
- with ops.device('/cpu:0'):
+ # Merge outputs under expected scope.
+ with ops.device('/cpu:0' if cpu_merge else '/gpu:%d' % target_gpu_ids[0]):
merged = []
for name, outputs in zip(model.output_names, all_outputs):
merged.append(concatenate(outputs, axis=0, name=name))
diff --git a/tensorflow/python/keras/_impl/keras/utils/np_utils.py b/tensorflow/python/keras/_impl/keras/utils/np_utils.py
index a611be08aa..9d9c72b162 100644
--- a/tensorflow/python/keras/_impl/keras/utils/np_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/np_utils.py
@@ -43,7 +43,7 @@ def to_categorical(y, num_classes=None):
if not num_classes:
num_classes = np.max(y) + 1
n = y.shape[0]
- categorical = np.zeros((n, num_classes))
+ categorical = np.zeros((n, num_classes), dtype=np.float32)
categorical[np.arange(n), y] = 1
output_shape = input_shape + (num_classes,)
categorical = np.reshape(categorical, output_shape)
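A quick sketch of the behaviour after this dtype change; the label values are arbitrary examples.

```python
import numpy as np
from tensorflow.python.keras._impl import keras

y = np.array([0, 2, 1, 2])
one_hot = keras.utils.to_categorical(y, num_classes=3)
one_hot.dtype   # dtype('float32') after this change (previously float64)
one_hot[:2]     # [[1., 0., 0.], [0., 0., 1.]]
```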
diff --git a/tensorflow/python/keras/utils/__init__.py b/tensorflow/python/keras/utils/__init__.py
index 2f74cf031d..9d924c8c90 100644
--- a/tensorflow/python/keras/utils/__init__.py
+++ b/tensorflow/python/keras/utils/__init__.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer
from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer
from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope
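With `OrderedEnqueuer` now exported here, it can be used to prefetch `Sequence` batches in order. A minimal sketch, assuming the upstream Keras enqueuer interface (`start`/`get`/`stop`); the `ToySequence` class is hypothetical.

```python
import numpy as np
from tensorflow.python.keras.utils import OrderedEnqueuer, Sequence


class ToySequence(Sequence):
  """Hypothetical Sequence yielding (x, y) batches."""

  def __len__(self):
    return 10

  def __getitem__(self, idx):
    x = np.full((4, 8), idx, dtype=np.float32)
    y = np.full((4,), idx, dtype=np.int64)
    return x, y


enqueuer = OrderedEnqueuer(ToySequence(), use_multiprocessing=False)
enqueuer.start(workers=2, max_queue_size=4)
batches = enqueuer.get()      # generator yielding batches in index order
x, y = next(batches)
enqueuer.stop()
```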
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index c892b6ee9a..83b353600a 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -742,6 +742,18 @@ tf_py_test(
)
tf_py_test(
+ name = "regex_full_match_op_test",
+ size = "small",
+ srcs = ["regex_full_match_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:string_ops",
+ ],
+)
+
+tf_py_test(
name = "save_restore_ops_test",
size = "small",
srcs = ["save_restore_ops_test.py"],
@@ -1222,6 +1234,7 @@ cuda_py_test(
shard_count = 10,
tags = [
"noasan", # times out
+ "optonly", # times out
],
)
@@ -2363,6 +2376,9 @@ cuda_py_test(
"//tensorflow/python:nn_grad",
"//tensorflow/python:nn_ops",
],
+ tags = [
+ "optonly", # flaky timeouts unless optimized
+ ],
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
index d6c0047747..13b804875e 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
@@ -1379,7 +1379,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
}
post_pruned_nodes_meta {
new_node_id: 0
- logit_change: -24.0143
+ logit_change: -24.014299
}
}
tree_metadata {
diff --git a/tensorflow/python/kernel_tests/control_flow_util_test.py b/tensorflow/python/kernel_tests/control_flow_util_test.py
index 39e96f74b0..762c445da0 100644
--- a/tensorflow/python/kernel_tests/control_flow_util_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_util_test.py
@@ -19,9 +19,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.platform import test
@@ -66,6 +70,111 @@ class ControlFlowUtilTest(test.TestCase):
self.assertFalse(control_flow_util.IsLoopExit(test_ops.int_output().op))
+ def build_test_graph(self):
+ g = ops.Graph()
+ with g.as_default():
+
+ def while_loop(x):
+
+ def b(x):
+ with ops.name_scope("NestedCond"):
+ return control_flow_ops.cond(
+ math_ops.less(x, 100), lambda: math_ops.add(x, 1),
+ lambda: math_ops.add(x, 2))
+
+ c = lambda x: math_ops.less(x, 10000)
+ with ops.name_scope("OuterWhile"):
+ return control_flow_ops.while_loop(c, b, [x])
+
+ x = array_ops.placeholder(dtypes.int32)
+ with ops.name_scope("OuterCond"):
+ control_flow_ops.cond(
+ math_ops.less(x, 1000), lambda: while_loop(x),
+ lambda: math_ops.add(x, 2))
+ return g
+
+ def testIsCondSwitch(self):
+ g = self.build_test_graph()
+
+ cond_switch = [
+ "OuterCond/cond/Switch",
+ "OuterCond/cond/OuterWhile/while/Switch",
+ "OuterCond/cond/OuterWhile/while/NestedCond/cond/Switch",
+ "OuterCond/cond/OuterWhile/while/NestedCond/cond/Add/Switch",
+ "OuterCond/cond/OuterWhile/while/NestedCond/cond/Add_1/Switch",
+ "OuterCond/cond/Add/Switch",
+ ]
+ for n in g.get_operations():
+ if control_flow_util.IsSwitch(n):
+ self.assertTrue(
+ control_flow_util.IsCondSwitch(n) != control_flow_util.IsLoopSwitch(
+ n))
+ if n.name in cond_switch:
+ self.assertTrue(control_flow_util.IsSwitch(n))
+ self.assertTrue(
+ control_flow_util.IsCondSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+ self.assertFalse(
+ control_flow_util.IsLoopSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+ else:
+ self.assertFalse(
+ control_flow_util.IsCondSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+
+ def testIsLoopSwitch(self):
+ g = self.build_test_graph()
+
+ loop_switch = ["OuterCond/cond/OuterWhile/while/Switch_1"]
+ for n in g.get_operations():
+ if control_flow_util.IsSwitch(n):
+ self.assertTrue(
+ control_flow_util.IsCondSwitch(n) != control_flow_util.IsLoopSwitch(
+ n))
+ if n.name in loop_switch:
+ self.assertTrue(control_flow_util.IsSwitch(n))
+ self.assertFalse(
+ control_flow_util.IsCondSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+ self.assertTrue(
+ control_flow_util.IsLoopSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+ else:
+ self.assertFalse(
+ control_flow_util.IsLoopSwitch(n),
+ msg="Mismatch for {}".format(n.name))
+
+ def testIsCondMerge(self):
+ g = self.build_test_graph()
+ cond_merges = [
+ "OuterCond/cond/OuterWhile/while/NestedCond/cond/Merge",
+ "OuterCond/cond/Merge"
+ ]
+ for n in g.get_operations():
+ if n.name in cond_merges:
+ self.assertTrue(control_flow_util.IsMerge(n))
+ self.assertTrue(control_flow_util.IsCondMerge(n))
+ self.assertFalse(control_flow_util.IsLoopMerge(n))
+ else:
+ self.assertFalse(control_flow_util.IsCondMerge(n))
+ self.assertTrue(not control_flow_util.IsMerge(n) or
+ control_flow_util.IsLoopMerge(n))
+
+ def testIsLoopMerge(self):
+ g = self.build_test_graph()
+ loop_merges = [
+ "OuterCond/cond/OuterWhile/while/Merge",
+ ]
+ for n in g.get_operations():
+ if n.name in loop_merges:
+ self.assertTrue(control_flow_util.IsMerge(n))
+ self.assertFalse(control_flow_util.IsCondMerge(n))
+ self.assertTrue(control_flow_util.IsLoopMerge(n))
+ else:
+ self.assertFalse(control_flow_util.IsLoopMerge(n))
+ self.assertTrue(not control_flow_util.IsMerge(n) or
+ control_flow_util.IsCondMerge(n))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/conv3d_transpose_test.py b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
index 8973a450fa..289ae29fce 100644
--- a/tensorflow/python/kernel_tests/conv3d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
@@ -131,6 +131,23 @@ class Conv3DTransposeTest(test.TestCase):
nn_ops.conv3d_transpose(
x_value, f_value, y_shape, strides, data_format='NCDHW')
+ def testConv3DTransposeOutputShapeType(self):
+ # Test case for GitHub issue 18887
+ for dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session():
+ x_shape = [2, 5, 6, 4, 3]
+ y_shape = [2, 5, 6, 4, 2]
+ f_shape = [3, 3, 3, 2, 3]
+ strides = [1, 1, 1, 1, 1]
+ x_value = constant_op.constant(
+ 1.0, shape=x_shape, name="x", dtype=dtypes.float32)
+ f_value = constant_op.constant(
+ 1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
+ output = nn_ops.conv3d_transpose(
+ x_value, f_value, constant_op.constant(y_shape, dtype=dtype),
+ strides=strides, padding="SAME")
+ output.eval()
+
def testConv3DTransposeValid(self):
with self.test_session():
strides = [1, 2, 2, 2, 1]
diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD
index f3cc9636f9..cf2e8832fd 100644
--- a/tensorflow/python/kernel_tests/distributions/BUILD
+++ b/tensorflow/python/kernel_tests/distributions/BUILD
@@ -41,6 +41,7 @@ cuda_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
+ shard_count = 3,
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py
index 2d434a39c2..d5d50a180a 100644
--- a/tensorflow/python/kernel_tests/distributions/special_math_test.py
+++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py
@@ -23,11 +23,14 @@ import importlib
import numpy as np
+from tensorflow.python.eager import backprop as tfe_backprop
+from tensorflow.python.eager import context as tfe_context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import variables
from tensorflow.python.ops.distributions import special_math
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@@ -64,6 +67,16 @@ def _make_grid(dtype, grid_spec):
return np.reshape(grid, grid_spec.shape)
+def _value_and_gradient(fn, *args):
+ """Calls `fn` and computes the gradient of the result wrt `arg`."""
+ if tfe_context.executing_eagerly():
+ v, g = tfe_backprop.val_and_grad_function(fn)(args)
+ else:
+ v = fn(*args)
+ g = gradients_impl.gradients(v, args)
+ return v, g
+
+
GridSpec = collections.namedtuple("GridSpec", ["min", "max", "shape"])
ErrorSpec = collections.namedtuple("ErrorSpec", ["rtol", "atol"])
@@ -71,11 +84,12 @@ ErrorSpec = collections.namedtuple("ErrorSpec", ["rtol", "atol"])
class NdtriTest(test.TestCase):
- def assertAllFinite(self, tensor):
- is_finite = np.isfinite(tensor.eval())
+ def assertAllFinite(self, x):
+ is_finite = np.isfinite(x)
all_true = np.ones_like(is_finite, dtype=np.bool)
self.assertAllEqual(all_true, is_finite)
+ @test_util.run_in_graph_and_eager_modes()
def testNdtri(self):
"""Verifies that ndtri computation is correct."""
with self.test_session():
@@ -89,7 +103,7 @@ class NdtriTest(test.TestCase):
np.exp(-2), 1. - np.exp(-2)))
expected_x = special.ndtri(p)
x = special_math.ndtri(p)
- self.assertAllClose(expected_x, x.eval(), atol=0.)
+ self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
def testNdtriDynamicShape(self):
"""Verifies that ndtri computation is correct."""
@@ -108,23 +122,27 @@ class NdtriTest(test.TestCase):
def _baseNdtriFiniteGradientTest(self, dtype):
"""Verifies that ndtri has finite gradients at interesting points."""
- g = ops.Graph()
- with g.as_default():
- # Tests gradients at 0, 1, and piece-wise boundaries.
- p = variables.Variable(
- np.array([0.,
- np.exp(-32.), np.exp(-2.),
- 1. - np.exp(-2.), 1. - np.exp(-32.),
- 1.]).astype(dtype))
- value = special_math.ndtri(p)
- grads = gradients_impl.gradients(value, p)
- with self.test_session(graph=g):
- variables.global_variables_initializer().run()
- self.assertAllFinite(grads[0])
-
+ # Tests gradients at 0, 1, and piece-wise boundaries.
+ p = constant_op.constant(
+ np.array([
+ 0.,
+ np.exp(-32.),
+ np.exp(-2.),
+ 1. - np.exp(-2.),
+ 1. - np.exp(-32.),
+ 1.,
+ ]).astype(dtype))
+ # Not having the lambda sanitizer means we'd get an `IndexError` whenever
+ # the user-supplied function has default args.
+ _, grads = _value_and_gradient(
+ lambda x: special_math.ndtri(x), p) # pylint: disable=unnecessary-lambda
+ self.assertAllFinite(self.evaluate(grads[0]))
+
+ @test_util.run_in_graph_and_eager_modes()
def testNdtriFiniteGradientFloat32(self):
self._baseNdtriFiniteGradientTest(np.float32)
+ @test_util.run_in_graph_and_eager_modes()
def testNdtriFiniteGradientFloat64(self):
self._baseNdtriFiniteGradientTest(np.float64)
@@ -147,55 +165,53 @@ class NdtrTest(test.TestCase):
if not special:
return
- with self.test_session():
- grid = _make_grid(dtype, grid_spec)
- actual = sm.log_ndtr(grid).eval()
-
- # Basic tests.
- # isfinite checks for NaN and Inf.
- self.assertTrue(np.isfinite(actual).all())
- # On the grid, -inf < log_cdf(x) < 0. In this case, we should be able
- # to use a huge grid because we have used tricks to escape numerical
- # difficulties.
- self.assertTrue((actual < 0).all())
- _check_strictly_increasing(actual)
-
- # Versus scipy.
- expected = special.log_ndtr(grid)
- # Scipy prematurely goes to zero at some places that we don't. So don't
- # include these in the comparison.
- self.assertAllClose(
- expected.astype(np.float64)[expected < 0],
- actual.astype(np.float64)[expected < 0],
- rtol=error_spec.rtol,
- atol=error_spec.atol)
+ grid = _make_grid(dtype, grid_spec)
+ actual = self.evaluate(sm.log_ndtr(grid))
+
+ # Basic tests.
+ # isfinite checks for NaN and Inf.
+ self.assertTrue(np.isfinite(actual).all())
+ # On the grid, -inf < log_cdf(x) < 0. In this case, we should be able
+ # to use a huge grid because we have used tricks to escape numerical
+ # difficulties.
+ self.assertTrue((actual < 0).all())
+ _check_strictly_increasing(actual)
+
+ # Versus scipy.
+ expected = special.log_ndtr(grid)
+ # Scipy prematurely goes to zero at some places that we don't. So don't
+ # include these in the comparison.
+ self.assertAllClose(
+ expected.astype(np.float64)[expected < 0],
+ actual.astype(np.float64)[expected < 0],
+ rtol=error_spec.rtol,
+ atol=error_spec.atol)
def _test_grid_no_log(self, dtype, grid_spec, error_spec):
if not special:
return
- with self.test_session():
- grid = _make_grid(dtype, grid_spec)
- actual = sm.ndtr(grid).eval()
-
- # Basic tests.
- # isfinite checks for NaN and Inf.
- self.assertTrue(np.isfinite(actual).all())
- # On the grid, 0 < cdf(x) < 1. The grid cannot contain everything due
- # to numerical limitations of cdf.
- self.assertTrue((actual > 0).all())
- self.assertTrue((actual < 1).all())
- _check_strictly_increasing(actual)
-
- # Versus scipy.
- expected = special.ndtr(grid)
- # Scipy prematurely goes to zero at some places that we don't. So don't
- # include these in the comparison.
- self.assertAllClose(
- expected.astype(np.float64)[expected < 0],
- actual.astype(np.float64)[expected < 0],
- rtol=error_spec.rtol,
- atol=error_spec.atol)
+ grid = _make_grid(dtype, grid_spec)
+ actual = self.evaluate(sm.ndtr(grid))
+
+ # Basic tests.
+ # isfinite checks for NaN and Inf.
+ self.assertTrue(np.isfinite(actual).all())
+ # On the grid, 0 < cdf(x) < 1. The grid cannot contain everything due
+ # to numerical limitations of cdf.
+ self.assertTrue((actual > 0).all())
+ self.assertTrue((actual < 1).all())
+ _check_strictly_increasing(actual)
+
+ # Versus scipy.
+ expected = special.ndtr(grid)
+ # Scipy prematurely goes to zero at some places that we don't. So don't
+ # include these in the comparison.
+ self.assertAllClose(
+ expected.astype(np.float64)[expected < 0],
+ actual.astype(np.float64)[expected < 0],
+ rtol=error_spec.rtol,
+ atol=error_spec.atol)
def test_float32(self):
self._test_grid(np.float32, self._grid32, self._error32)
@@ -254,14 +270,17 @@ class NdtrGradientTest(test.TestCase):
self.assertAllEqual(np.zeros_like(v, dtype=np.bool), v)
def _test_grad_finite(self, dtype):
- with self.test_session():
- x = variables.Variable([-100., 0., 100.], dtype=dtype)
- output = (sm.log_ndtr(x) if self._use_log else sm.ndtr(x))
- grad_output = gradients_impl.gradients(output, x)
- variables.global_variables_initializer().run()
- # isfinite checks for NaN and Inf.
- self.assert_all_true(np.isfinite(output.eval()))
- self.assert_all_true(np.isfinite(grad_output[0].eval()))
+ x = constant_op.constant([-100., 0., 100.], dtype=dtype)
+ output = (sm.log_ndtr(x) if self._use_log else sm.ndtr(x))
+ fn = sm.log_ndtr if self._use_log else sm.ndtr
+ # Not having the lambda sanitizer means we'd get an `IndexError` whenever
+ # the user-supplied function has default args.
+ output, grad_output = _value_and_gradient(
+ lambda x_: fn(x_), x) # pylint: disable=unnecessary-lambda
+ # isfinite checks for NaN and Inf.
+ output_, grad_output_ = self.evaluate([output, grad_output])
+ self.assert_all_true(np.isfinite(output_))
+ self.assert_all_true(np.isfinite(grad_output_[0]))
def _test_grad_accuracy(self, dtype, grid_spec, error_spec):
raw_grid = _make_grid(dtype, grid_spec)
@@ -357,7 +376,6 @@ class ErfInvTest(test.TestCase):
special_math.erfinv(x)
-
class LogCDFLaplaceTest(test.TestCase):
# Note that scipy.stats.laplace does not have a stable Log CDF, so we cannot
# rely on scipy to cross check the extreme values.
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index dc465c867f..619a81bea5 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -1017,6 +1017,62 @@ class SoftplusTest(test.TestCase):
self.assertAllEqual(
np.ones_like(grads).astype(np.bool), np.isfinite(grads))
+class ArgumentsTest(test.TestCase):
+
+ def testNoArguments(self):
+ def foo():
+ return du.parent_frame_arguments()
+
+ self.assertEqual({}, foo())
+
+ def testPositionalArguments(self):
+ def foo(a, b, c, d): # pylint: disable=unused-argument
+ return du.parent_frame_arguments()
+
+ self.assertEqual({"a": 1, "b": 2, "c": 3, "d": 4}, foo(1, 2, 3, 4))
+
+ # Tests that it does not matter where this function is called, and
+ # no other local variables are returned back.
+ def bar(a, b, c):
+ unused_x = a * b
+ unused_y = c * 3
+ return du.parent_frame_arguments()
+
+ self.assertEqual({"a": 1, "b": 2, "c": 3}, bar(1, 2, 3))
+
+ def testOverloadedArgumentValues(self):
+ def foo(a, b, c): # pylint: disable=unused-argument
+ a = 42
+ b = 31
+ c = 42
+ return du.parent_frame_arguments()
+ self.assertEqual({"a": 42, "b": 31, "c": 42}, foo(1, 2, 3))
+
+ def testKeywordArguments(self):
+ def foo(**kwargs): # pylint: disable=unused-argument
+ return du.parent_frame_arguments()
+
+ self.assertEqual({"a": 1, "b": 2, "c": 3, "d": 4}, foo(a=1, b=2, c=3, d=4))
+
+ def testPositionalKeywordArgs(self):
+ def foo(a, b, c, **kwargs): # pylint: disable=unused-argument
+ return du.parent_frame_arguments()
+
+ self.assertEqual({"a": 1, "b": 2, "c": 3}, foo(a=1, b=2, c=3))
+ self.assertEqual({"a": 1, "b": 2, "c": 3, "unicorn": None},
+ foo(a=1, b=2, c=3, unicorn=None))
+
+ def testNoVarargs(self):
+ def foo(a, b, c, *varargs, **kwargs): # pylint: disable=unused-argument
+ return du.parent_frame_arguments()
+
+ self.assertEqual({"a": 1, "b": 2, "c": 3}, foo(a=1, b=2, c=3))
+ self.assertEqual({"a": 1, "b": 2, "c": 3}, foo(1, 2, 3, *[1, 2, 3]))
+ self.assertEqual({"a": 1, "b": 2, "c": 3, "unicorn": None},
+ foo(1, 2, 3, unicorn=None))
+ self.assertEqual({"a": 1, "b": 2, "c": 3, "unicorn": None},
+ foo(1, 2, 3, *[1, 2, 3], unicorn=None))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 35a274e75f..5489338bc0 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for tensorflow.kernels.bcast_ops."""
+"""Tests for tensorflow.kernels.functional_ops."""
from __future__ import absolute_import
from __future__ import division
@@ -670,117 +670,117 @@ class FunctionalOpsTest(test.TestCase):
with self.test_session(use_gpu=False) as sess:
- def Run(x):
- return sess.run(
- functional_ops.If(math_ops.greater(x, 0), [x], Twice, Thrice))[0]
+ x = array_ops.placeholder(dtypes.float32)
+ ret = functional_ops.If(math_ops.greater(x, 0), [x], Twice, Thrice)[0]
- self.assertAllEqual(Run(9.), 18.)
- self.assertAllEqual(Run(-8.), -23.)
- self.assertAllEqual(Run(0.), 1.)
+ self.assertAllEqual(sess.run(ret, feed_dict={x: 9.}), 18.)
+ self.assertAllEqual(sess.run(ret, feed_dict={x: -8.}), -23.)
+ self.assertAllEqual(sess.run(ret, feed_dict={x: 0.}), 1.)
def testWhile(self):
- @function.Defun(*[dtypes.float32] * 2)
- def Cond(n, unused_x):
- return n > 0
+ for use_gpu in (True, False):
+ with ops.Graph().as_default() as g:
- @function.Defun(*[dtypes.float32] * 2)
- def Body(n, x):
- return n - 1, x + n
+ @function.Defun(*[dtypes.float32] * 2)
+ def Cond(n, unused_x):
+ return n > 0
- # TODO(b/65752372): Set `use_gpu=False` because
- # `functional_ops.While()` does not reliably work on GPU (apparently
- # because the result of evaluating the condition may be in device
- # memory, but it is read on the host).
- with self.test_session(use_gpu=False) as sess:
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(n, x):
+ return n - 1, x + n
- def Run(n):
- return sess.run(functional_ops.While([n, 0.], Cond, Body))[1]
+ def Run(sess, n):
+ return sess.run(functional_ops.While([n, 0.], Cond, Body))[1]
- self.assertAllEqual(Run(20.), 210.)
- self.assertAllEqual(Run(100.), 5050.)
+ with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ self.assertAllEqual(Run(sess, 20.), 210.)
+ self.assertAllEqual(Run(sess, 100.), 5050.)
def testWhileError(self):
-
- @function.Defun(*[dtypes.float32] * 2)
- def Cond(n, unused_x):
- return n > 0
-
- @function.Defun(*[dtypes.float32] * 2)
- def CondReturnsTooManyArgs(n, x):
- return n > 0, x
-
- @function.Defun(*[dtypes.float32] * 2)
- def Body(n, x):
- return n - 1, x + n
-
- @function.Defun(*[dtypes.float32] * 2)
- def BodyReturnsTooManyArgs(n, x):
- return n - 1, x + n, x
-
- # TODO(b/65752372): Set `use_gpu=False` because
- # `functional_ops.While()` does not reliably work on GPU (apparently
- # because the result of evaluating the condition may be in device
- # memory, but it is read on the host).
- with self.test_session(use_gpu=False):
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "Expected a single scalar.*got 2 tensors."):
- functional_ops.While([5., 0.], CondReturnsTooManyArgs, Body)[0].eval()
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "While loop body returned 3 arguments. Expected: 2"):
- functional_ops.While([5., 0.], Cond, BodyReturnsTooManyArgs)[0].eval()
+ for use_gpu in (True, False):
+ with ops.Graph().as_default() as g:
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Cond(n, unused_x):
+ return n > 0
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def CondReturnsTooManyArgs(n, x):
+ return n > 0, x
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(n, x):
+ return n - 1, x + n
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def BodyReturnsTooManyArgs(n, x):
+ return n - 1, x + n, x
+
+ with self.test_session(graph=g, use_gpu=use_gpu):
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Expected a single scalar.*got 2 tensors."):
+ functional_ops.While([5., 0.], CondReturnsTooManyArgs,
+ Body)[0].eval()
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "While loop body returned 3 arguments. Expected: 2"):
+ functional_ops.While([5., 0.], Cond,
+ BodyReturnsTooManyArgs)[0].eval()
def testWhileInMultipleSubgraphs(self):
- @function.Defun(* [dtypes.float32] * 2)
- def Cond(n, x): # pylint: disable=unused-argument
- return n > 0
-
- @function.Defun(* [dtypes.float32] * 2)
- def Body(n, x):
- return n - 1, x + n
-
- # TODO(b/65752372): Set `use_gpu=False` because
- # `functional_ops.While()` does not reliably work on GPU (apparently
- # because the result of evaluating the condition may be in device
- # memory, but it is read on the host).
- with self.test_session(use_gpu=False) as sess:
- n = array_ops.placeholder(dtypes.float32)
- _, result = functional_ops.While([n, 0.], Cond, Body)
- c = constant_op.constant(37.)
-
- self.assertAllEqual(210., sess.run(result, feed_dict={n: 20.}))
- self.assertAllEqual(5050., sess.run(result, feed_dict={n: 100.}))
- # Test that the result is the same when we run a different subgraph.
- self.assertAllEqual(5050., sess.run([result, c], feed_dict={n: 100.})[0])
-
- def _tfSum(self, rewrite_with_while):
- # On GPU, don't rewrite using a while loop.
- use_gpu = not rewrite_with_while
- with self.test_session(use_gpu=use_gpu) as sess:
-
- @function.Defun(dtypes.int32, dtypes.float32)
- def Body(n, x):
- return x + math_ops.to_float(n)
-
- xs = [
- # 1 + 2 + ... + 20
- functional_ops.For(
- 1, 21, 1, [0.], Body, rewrite_with_while=rewrite_with_while)[0],
- # 100 + 99 + ... + 1
- functional_ops.For(
- 100, 0, -1, [0.], Body, rewrite_with_while=rewrite_with_while)[0],
- ]
- xvals = sess.run(xs)
- self.assertAllEqual(210, xvals[0])
- self.assertAllEqual(5050, xvals[1])
+ for use_gpu in (True, False):
+ with ops.Graph().as_default() as g:
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Cond(n, x): # pylint: disable=unused-argument
+ return n > 0
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(n, x):
+ return n - 1, x + n
+
+ with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ n = array_ops.placeholder(dtypes.float32)
+ _, result = functional_ops.While([n, 0.], Cond, Body)
+ c = constant_op.constant(37.)
+
+ self.assertAllEqual(210., sess.run(result, feed_dict={n: 20.}))
+ self.assertAllEqual(5050., sess.run(result, feed_dict={n: 100.}))
+ # Test that the result is the same when we run a different subgraph.
+ self.assertAllEqual(5050.,
+ sess.run([result, c], feed_dict={n: 100.})[0])
+
+ def _tfSum(self, use_gpu, rewrite_with_while):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+
+ @function.Defun(dtypes.int32, dtypes.float32)
+ def Body(n, x):
+ return x + math_ops.to_float(n)
+
+ xs = [
+ # 1 + 2 + ... + 20
+ functional_ops.For(
+ 1, 21, 1, [0.], Body, rewrite_with_while=rewrite_with_while)[0],
+ # 100 + 99 + ... + 1
+ functional_ops.For(
+ 100, 0, -1, [0.], Body, rewrite_with_while=rewrite_with_while)
+ [0],
+ ]
+ xvals = sess.run(xs)
+ self.assertAllEqual(210, xvals[0])
+ self.assertAllEqual(5050, xvals[1])
def testFor(self):
- self._tfSum(False)
+ for use_gpu in (True, False):
+ self._tfSum(use_gpu, False)
def testForWithWhile(self):
- self._tfSum(True)
+ for use_gpu in (True, False):
+ self._tfSum(use_gpu, True)
def testForWithWhileNaming(self):
g = ops.Graph()
@@ -816,10 +816,6 @@ class FunctionalOpsTest(test.TestCase):
return x + math_ops.to_float(n) + v, x2 + v
for rewrite_with_while in (True, False):
- # TODO(b/65752372): Set `use_gpu=False` because
- # `functional_ops.While()` does not reliably work on GPU (apparently
- # because the result of evaluating the condition may be in device
- # memory, but it is read on the host).
use_gpu = not rewrite_with_while
with self.test_session(use_gpu=use_gpu) as sess:
result_nullary = functional_ops.For(
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
index cce1ecd45e..784c730bbc 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
@@ -97,6 +97,10 @@ class SquareLinearOperatorKroneckerTest(
build_info((3, 6, 6), factors=[(3, 1, 1), (1, 2, 2), (1, 3, 3)]),
]
+ @property
+ def _tests_to_skip(self):
+ return ["det", "solve", "solve_with_broadcast"]
+
def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
expected_factors = build_info.__dict__["factors"]
diff --git a/tensorflow/python/kernel_tests/regex_full_match_op_test.py b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
new file mode 100644
index 0000000000..5daae1b79b
--- /dev/null
+++ b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
@@ -0,0 +1,54 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for RegexFullMatch op from string_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class RegexFullMatchOpTest(test.TestCase):
+
+ def testRegexFullMatch(self):
+ values = ["abaaba", "abcdabcde"]
+ with self.test_session():
+ input_vector = constant_op.constant(values, dtypes.string)
+ matched = string_ops.regex_full_match(input_vector, "a.*a").eval()
+ self.assertAllEqual([True, False], matched)
+
+ def testEmptyMatch(self):
+ values = ["abc", "1"]
+ with self.test_session():
+ input_vector = constant_op.constant(values, dtypes.string)
+ matched = string_ops.regex_full_match(input_vector, "").eval()
+ self.assertAllEqual([False, False], matched)
+
+ def testInvalidPattern(self):
+ values = ["abc", "1"]
+ with self.test_session():
+ input_vector = constant_op.constant(values, dtypes.string)
+ invalid_pattern = "A["
+ matched = string_ops.regex_full_match(input_vector, invalid_pattern)
+ with self.assertRaisesOpError("Invalid pattern"):
+ matched.eval()
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index b7477a768a..79fe927b8a 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -23,8 +23,11 @@ import functools
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
@@ -364,6 +367,15 @@ class ScatterNdTest(test.TestCase):
del input_ # input_ is not used in scatter_nd
return array_ops.scatter_nd(indices, updates, shape)
+ @test_util.run_in_graph_and_eager_modes()
+ def testInvalidShape(self):
+ # TODO(apassos) figure out how to unify these errors
+ with self.assertRaises(errors.InvalidArgumentError
+ if context.executing_eagerly() else ValueError):
+ array_ops.scatter_nd(indices=[0], # this should be indices=[[0]]
+ updates=[0.0],
+ shape=[1])
+
def testString(self):
indices = constant_op.constant([[4], [3], [1], [7]],
dtype=dtypes.int32)
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index 3bca5fadc4..794be096b7 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -91,16 +91,18 @@ class SegmentReductionOpTest(SegmentReductionHelper):
]
# Each item is np_op1, np_op2, tf_op
- ops_list = [(np.add, None, math_ops.segment_sum), (self._mean_cum_op,
- self._mean_reduce_op,
- math_ops.segment_mean),
+ ops_list = [(np.add, None, math_ops.segment_sum),
+ (self._mean_cum_op, self._mean_reduce_op,
+ math_ops.segment_mean),
(np.ndarray.__mul__, None, math_ops.segment_prod),
(np.minimum, None, math_ops.segment_min),
(np.maximum, None, math_ops.segment_max)]
# A subset of ops has been enabled for complex numbers
complex_ops_list = [(np.add, None, math_ops.segment_sum),
- (np.ndarray.__mul__, None, math_ops.segment_prod)]
+ (np.ndarray.__mul__, None, math_ops.segment_prod),
+ (self._mean_cum_op, self._mean_reduce_op,
+ math_ops.segment_mean)]
n = 10
shape = [n, 2]
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 51aa671098..9dc4ec0f96 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -40,6 +40,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
+from tensorflow.python.util import compat
class VariableScopeTest(test.TestCase):
@@ -110,6 +111,12 @@ class VariableScopeTest(test.TestCase):
w = variable_scope.get_variable("w", [])
self.assertEqual(w.constraint, constraint)
+ def testStringDefaultInitializer(self):
+ with self.test_session():
+ v = variable_scope.get_variable("string", shape=[], dtype=dtypes.string)
+ variables_lib.global_variables_initializer().run()
+ self.assertAllEqual(compat.as_bytes(v.eval()), b"")
+
@test_util.run_in_graph_and_eager_modes()
def testVarScopeDType(self):
with variable_scope.variable_scope("tower2") as tower:
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index aa43a153c2..1cf7d2abd1 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -20,12 +20,12 @@ from __future__ import print_function
import copy
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras._impl.keras.engine import base_layer
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -319,7 +319,7 @@ class Layer(base_layer.Layer):
try:
call_has_scope_arg = self._call_has_scope_arg
except AttributeError:
- self._call_fn_args = estimator_util.fn_args(self.call)
+ self._call_fn_args = function_utils.fn_args(self.call)
self._call_has_scope_arg = 'scope' in self._call_fn_args
call_has_scope_arg = self._call_has_scope_arg
if call_has_scope_arg:
diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc
index a07e305ffb..9df38d464c 100644
--- a/tensorflow/python/lib/core/ndarray_tensor.cc
+++ b/tensorflow/python/lib/core/ndarray_tensor.cc
@@ -145,7 +145,7 @@ Status PyBytesArrayMap(PyArrayObject* array, F f) {
while (PyArray_ITER_NOTDONE(iter.get())) {
auto item = tensorflow::make_safe(PyArray_GETITEM(
array, static_cast<char*>(PyArray_ITER_DATA(iter.get()))));
- if (!item.get()) {
+ if (!item) {
return errors::Internal("Unable to get element from the feed - no item.");
}
char* ptr;
diff --git a/tensorflow/python/lib/core/ndarray_tensor_bridge.cc b/tensorflow/python/lib/core/ndarray_tensor_bridge.cc
index 65e2178cda..0d5838505f 100644
--- a/tensorflow/python/lib/core/ndarray_tensor_bridge.cc
+++ b/tensorflow/python/lib/core/ndarray_tensor_bridge.cc
@@ -72,10 +72,11 @@ struct TensorReleaser {
extern PyTypeObject TensorReleaserType;
-static void TensorReleaser_dealloc(TensorReleaser* self) {
+static void TensorReleaser_dealloc(PyObject* pself) {
+ TensorReleaser* self = reinterpret_cast<TensorReleaser*>(pself);
(*self->destructor)();
delete self->destructor;
- TensorReleaserType.tp_free(self);
+ TensorReleaserType.tp_free(pself);
}
PyTypeObject TensorReleaserType = {
@@ -84,26 +85,26 @@ PyTypeObject TensorReleaserType = {
sizeof(TensorReleaser), /* tp_basicsize */
0, /* tp_itemsize */
/* methods */
- (destructor)TensorReleaser_dealloc, /* tp_dealloc */
- nullptr, /* tp_print */
- nullptr, /* tp_getattr */
- nullptr, /* tp_setattr */
- nullptr, /* tp_compare */
- nullptr, /* tp_repr */
- nullptr, /* tp_as_number */
- nullptr, /* tp_as_sequence */
- nullptr, /* tp_as_mapping */
- nullptr, /* tp_hash */
- nullptr, /* tp_call */
- nullptr, /* tp_str */
- nullptr, /* tp_getattro */
- nullptr, /* tp_setattro */
- nullptr, /* tp_as_buffer */
- Py_TPFLAGS_DEFAULT, /* tp_flags */
- "Wrapped TensorFlow Tensor", /* tp_doc */
- nullptr, /* tp_traverse */
- nullptr, /* tp_clear */
- nullptr, /* tp_richcompare */
+ TensorReleaser_dealloc, /* tp_dealloc */
+ nullptr, /* tp_print */
+ nullptr, /* tp_getattr */
+ nullptr, /* tp_setattr */
+ nullptr, /* tp_compare */
+ nullptr, /* tp_repr */
+ nullptr, /* tp_as_number */
+ nullptr, /* tp_as_sequence */
+ nullptr, /* tp_as_mapping */
+ nullptr, /* tp_hash */
+ nullptr, /* tp_call */
+ nullptr, /* tp_str */
+ nullptr, /* tp_getattro */
+ nullptr, /* tp_setattro */
+ nullptr, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT, /* tp_flags */
+ "Wrapped TensorFlow Tensor", /* tp_doc */
+ nullptr, /* tp_traverse */
+ nullptr, /* tp_clear */
+ nullptr, /* tp_richcompare */
};
Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 5f60dab6ac..5ebdb19079 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1685,12 +1685,12 @@ class CondContext(ControlFlowContext):
self._pivot = pivot # The predicate tensor in this branch
self._branch = branch # 0 or 1 representing this branch
- # Values considered to have been already seen in this context. They are
- # not included in this context.
+ # Values considered to have been already seen in this context. pred is not
+ # included in this context.
self._values.add(pred.name)
self._external_values[pred.name] = pred
self._values.add(pivot.name)
- self._external_values[pivot.name] = pivot
+ pivot.op._set_control_flow_context(self) # pylint: disable=protected-access
def _init_from_proto(self, context_def, import_scope=None):
"""Creates a new `CondContext` from protocol buffer.
diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py
index eee31102db..7a18986c5b 100644
--- a/tensorflow/python/ops/control_flow_util.py
+++ b/tensorflow/python/ops/control_flow_util.py
@@ -53,6 +53,11 @@ def IsSwitch(op):
return op.type == "Switch" or op.type == "RefSwitch"
+def IsMerge(op):
+ """Return true if `op` is a Merge."""
+ return op.type == "Merge" or op.type == "RefMerge"
+
+
def IsLoopEnter(op):
"""Returns true if `op` is an Enter."""
return op.type == "Enter" or op.type == "RefEnter"
@@ -63,11 +68,57 @@ def IsLoopExit(op):
return op.type == "Exit" or op.type == "RefExit"
+def IsCondSwitch(op):
+ """Return true if `op` is the Switch for a conditional."""
+ if not IsSwitch(op):
+ return False
+ if not op.outputs:
+ return False
+ # Switch nodes are not part of the cond control flow context that they
+ # represent, so consider the consumers of a switch's outputs to determine
+ # whether it is a cond switch. A switch is a cond switch iff all its
+ # consumers are in cond contexts.
+ is_cond_switch = True
+ for o in op.outputs:
+ for c in o.consumers():
+ ctxt = c._get_control_flow_context() # pylint: disable=protected-access
+ if IsLoopEnter(c):
+ ctxt = ctxt.outer_context
+ is_cond_switch = is_cond_switch and (ctxt is not None and
+ ctxt.IsCondContext())
+ return is_cond_switch
+
+
+def IsCondMerge(op):
+ """Return true if `op` is the Merge for a conditional."""
+ if not IsMerge(op):
+ return False
+ if not op.inputs:
+ return False
+ # Merge nodes are not part of the cond control flow context that they
+ # represent, so consider the inputs to the merge to determine whether it is
+ # a cond merge. A merge is a cond merge iff all its inputs are in cond
+ # contexts.
+ is_cond_merge = True
+ for i in op.inputs:
+ ctxt = GetOutputContext(i.op)
+ is_cond_merge = is_cond_merge and ctxt is not None and ctxt.IsCondContext()
+ return is_cond_merge
+
+
def IsLoopSwitch(op):
"""Return true if `op` is the Switch for a while loop."""
if IsSwitch(op):
ctxt = op._get_control_flow_context() # pylint: disable=protected-access
- return ctxt and ctxt.IsWhileContext()
+ return ctxt is not None and ctxt.IsWhileContext() and not IsCondSwitch(op)
+ return False
+
+
+def IsLoopMerge(op):
+ """Return true if `op` is the Merge for a while loop."""
+ if IsMerge(op):
+ ctxt = op._get_control_flow_context() # pylint: disable=protected-access
+ return ctxt is not None and ctxt.IsWhileContext() and not IsCondMerge(op)
return False
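The helpers above classify Switch and Merge nodes by the control flow contexts of their neighbors rather than of the node itself, since the node is deliberately kept outside the context it delimits. A minimal graph-mode sketch of the distinction, assuming TF 1.x and the private control_flow_util module touched above:

import tensorflow as tf
from tensorflow.python.ops import control_flow_util

g = tf.Graph()
with g.as_default():
  pred = tf.placeholder(tf.bool, shape=[])
  x = tf.constant(1.0)
  # tf.cond inserts Switch and Merge nodes around the two branches.
  _ = tf.cond(pred, lambda: x + 1.0, lambda: x - 1.0)

switches = [op for op in g.get_operations() if control_flow_util.IsSwitch(op)]
merges = [op for op in g.get_operations() if control_flow_util.IsMerge(op)]
# Every Switch/Merge here belongs to a cond, not to a while loop.
assert switches and all(control_flow_util.IsCondSwitch(op) for op in switches)
assert merges and not any(control_flow_util.IsLoopMerge(op) for op in merges)

Note that IsCondSwitch relies on the pivot ops carrying the cond context, which is exactly what the control_flow_ops.py change above now arranges.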
diff --git a/tensorflow/python/ops/distributions/bernoulli.py b/tensorflow/python/ops/distributions/bernoulli.py
index 2c9f0e9a32..d7fb3f1f78 100644
--- a/tensorflow/python/ops/distributions/bernoulli.py
+++ b/tensorflow/python/ops/distributions/bernoulli.py
@@ -71,7 +71,7 @@ class Bernoulli(distribution.Distribution):
Raises:
ValueError: If p and logits are passed, or if neither are passed.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits=logits,
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index 8beab99bf8..b697848600 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -150,7 +150,7 @@ class Beta(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration1, concentration0]) as name:
self._concentration1 = self._maybe_assert_valid_concentration(
ops.convert_to_tensor(concentration1, name="concentration1"),
@@ -321,7 +321,7 @@ class BetaWithSoftplusConcentration(Beta):
validate_args=False,
allow_nan_stats=True,
name="BetaWithSoftplusConcentration"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration1,
concentration0]) as name:
super(BetaWithSoftplusConcentration, self).__init__(
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index 8f25b1149c..bbdc8c455a 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -182,7 +182,7 @@ class Categorical(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits=logits,
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index eafcd5c78f..8d0d1d860b 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -154,7 +154,7 @@ class Dirichlet(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration]) as name:
self._concentration = self._maybe_assert_valid_concentration(
ops.convert_to_tensor(concentration, name="concentration"),
diff --git a/tensorflow/python/ops/distributions/dirichlet_multinomial.py b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
index fe0ed7e07d..3a35e0caa0 100644
--- a/tensorflow/python/ops/distributions/dirichlet_multinomial.py
+++ b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
@@ -191,7 +191,7 @@ class DirichletMultinomial(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[total_count, concentration]) as name:
# Broadcasting works because:
# * The broadcasting convention is to prepend dimensions of size [1], and
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 3815abf72d..fd08bda9b9 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -524,7 +524,8 @@ class Distribution(_BaseDistribution):
def parameters(self):
"""Dictionary of parameters used to instantiate this `Distribution`."""
# Remove "self", "__class__", or other special variables. These can appear
- # if the subclass used `parameters = locals()`.
+ # if the subclass used:
+ # `parameters = distribution_util.parent_frame_arguments()`.
return dict((k, v) for k, v in self._parameters.items()
if not k.startswith("__") and k != "self")
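The `locals()` -> `distribution_util.parent_frame_arguments()` switch throughout these distribution constructors matters because `locals()` snapshots every name in the frame, including temporaries, while the helper keeps only the declared arguments plus flattened **kwargs. A rough stand-alone illustration using plain `inspect` (the `Toy` class and `parent_frame_args` helper are hypothetical simplifications, not TensorFlow code):

import inspect

def parent_frame_args():
  # Simplified stand-in: return the caller's named arguments plus **kwargs.
  frame = inspect.stack()[1][0]
  arg_names, _, keywords, local_vars = inspect.getargvalues(frame)
  args = {name: local_vars[name] for name in arg_names}
  args.update(local_vars.get(keywords) or {})
  return args

class Toy(object):
  def __init__(self, rate, validate_args=False, **kwargs):
    scratch = rate * 2.0                      # an unrelated temporary
    self._via_locals = dict(locals())         # picks up `scratch` as well
    self._via_helper = parent_frame_args()    # only the constructor arguments

t = Toy(3.0, extra="x")
assert "scratch" in t._via_locals
assert set(t._via_helper) == {"self", "rate", "validate_args", "extra"}
# The `parameters` property above then strips "self" and other special names.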
diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py
index cf0e729e1a..1e08f48d52 100644
--- a/tensorflow/python/ops/distributions/exponential.py
+++ b/tensorflow/python/ops/distributions/exponential.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import gamma
+from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -90,7 +91,7 @@ class Exponential(gamma.Gamma):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
# Even though all statistics are defined for valid inputs, this is not
# true in the parent class "Gamma." Therefore, passing
# allow_nan_stats=True
@@ -143,7 +144,7 @@ class ExponentialWithSoftplusRate(Exponential):
validate_args=False,
allow_nan_stats=True,
name="ExponentialWithSoftplusRate"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[rate]) as name:
super(ExponentialWithSoftplusRate, self).__init__(
rate=nn.softplus(rate, name="softplus_rate"),
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index d39f7c56d3..7ca690d9d2 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -126,7 +126,7 @@ class Gamma(distribution.Distribution):
Raises:
TypeError: if `concentration` and `rate` are different dtypes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration, rate]) as name:
with ops.control_dependencies([
check_ops.assert_positive(concentration),
@@ -261,7 +261,7 @@ class GammaWithSoftplusConcentrationRate(Gamma):
validate_args=False,
allow_nan_stats=True,
name="GammaWithSoftplusConcentrationRate"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[concentration, rate]) as name:
super(GammaWithSoftplusConcentrationRate, self).__init__(
concentration=nn.softplus(concentration,
diff --git a/tensorflow/python/ops/distributions/laplace.py b/tensorflow/python/ops/distributions/laplace.py
index 3ccfc618d1..ee3a6a40ff 100644
--- a/tensorflow/python/ops/distributions/laplace.py
+++ b/tensorflow/python/ops/distributions/laplace.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -100,7 +101,7 @@ class Laplace(distribution.Distribution):
Raises:
TypeError: if `loc` and `scale` are of different dtype.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
@@ -217,7 +218,7 @@ class LaplaceWithSoftplusScale(Laplace):
validate_args=False,
allow_nan_stats=True,
name="LaplaceWithSoftplusScale"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
super(LaplaceWithSoftplusScale, self).__init__(
loc=loc,
diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py
index ab77f5c1f8..036ba45ccc 100644
--- a/tensorflow/python/ops/distributions/multinomial.py
+++ b/tensorflow/python/ops/distributions/multinomial.py
@@ -182,7 +182,7 @@ class Multinomial(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[total_count, logits, probs]) as name:
self._total_count = ops.convert_to_tensor(total_count, name="total_count")
if validate_args:
diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py
index 20d4420e91..0620aae10d 100644
--- a/tensorflow/python/ops/distributions/normal.py
+++ b/tensorflow/python/ops/distributions/normal.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -131,7 +132,7 @@ class Normal(distribution.Distribution):
Raises:
TypeError: if `loc` and `scale` have different `dtype`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
@@ -243,7 +244,7 @@ class NormalWithSoftplusScale(Normal):
validate_args=False,
allow_nan_stats=True,
name="NormalWithSoftplusScale"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[scale]) as name:
super(NormalWithSoftplusScale, self).__init__(
loc=loc,
diff --git a/tensorflow/python/ops/distributions/special_math.py b/tensorflow/python/ops/distributions/special_math.py
index 1d605c5dfc..31b7a36fd3 100644
--- a/tensorflow/python/ops/distributions/special_math.py
+++ b/tensorflow/python/ops/distributions/special_math.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import math
import numpy as np
from tensorflow.python.framework import constant_op
@@ -42,15 +41,15 @@ __all__ = [
# then made more conservative just to be safe. (Conservative means use the
# expansion more than we probably need to.) See `NdtrTest` in
# special_math_test.py.
-LOGNDTR_FLOAT64_LOWER = -20
-LOGNDTR_FLOAT32_LOWER = -10
+LOGNDTR_FLOAT64_LOWER = np.array(-20, np.float64)
+LOGNDTR_FLOAT32_LOWER = np.array(-10, np.float32)
# Upper bound values were chosen by examining for which values of 'x'
# Log[cdf(x)] is 0, after which point we need to use the approximation
# Log[cdf(x)] = Log[1 - cdf(-x)] approx -cdf(-x). We chose a value slightly
# conservative, meaning we use the approximation earlier than needed.
-LOGNDTR_FLOAT64_UPPER = 8
-LOGNDTR_FLOAT32_UPPER = 5
+LOGNDTR_FLOAT64_UPPER = np.array(8, np.float64)
+LOGNDTR_FLOAT32_UPPER = np.array(5, np.float32)
def ndtr(x, name="ndtr"):
@@ -91,7 +90,7 @@ def ndtr(x, name="ndtr"):
def _ndtr(x):
"""Implements ndtr core logic."""
half_sqrt_2 = constant_op.constant(
- 0.5 * math.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
+ 0.5 * np.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
w = x * half_sqrt_2
z = math_ops.abs(w)
y = array_ops.where(math_ops.less(z, half_sqrt_2),
@@ -190,18 +189,18 @@ def _ndtri(p):
def _create_polynomial(var, coeffs):
"""Compute n_th order polynomial via Horner's method."""
- if not coeffs:
- return 0.
+ coeffs = np.array(coeffs, var.dtype.as_numpy_dtype)
+ if not coeffs.size:
+ return array_ops.zeros_like(var)
return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var
- maybe_complement_p = array_ops.where(p > 1. - np.exp(-2.), 1. - p, p)
+ maybe_complement_p = array_ops.where(p > -np.expm1(-2.), 1. - p, p)
# Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
# later on. The result from the computation when p == 0 is not used so any
# number that doesn't result in NaNs is fine.
- one_half = constant_op.constant(0.5, dtype=p.dtype)
sanitized_mcp = array_ops.where(
maybe_complement_p <= 0.,
- array_ops.fill(array_ops.shape(p), one_half),
+ array_ops.fill(array_ops.shape(p), np.array(0.5, p.dtype.as_numpy_dtype)),
maybe_complement_p)
# Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
@@ -216,10 +215,12 @@ def _ndtri(p):
# arrays based on whether p < exp(-32).
z = math_ops.sqrt(-2. * math_ops.log(sanitized_mcp))
first_term = z - math_ops.log(z) / z
- second_term_small_p = (_create_polynomial(1. / z, p2)
- / _create_polynomial(1. / z, q2)) / z
- second_term_otherwise = (_create_polynomial(1. / z, p1)
- / _create_polynomial(1. / z, q1)) / z
+ second_term_small_p = (
+ _create_polynomial(1. / z, p2) /
+ _create_polynomial(1. / z, q2) / z)
+ second_term_otherwise = (
+ _create_polynomial(1. / z, p1) /
+ _create_polynomial(1. / z, q1) / z)
x_for_small_p = first_term - second_term_small_p
x_otherwise = first_term - second_term_otherwise
@@ -330,23 +331,25 @@ def _log_ndtr_lower(x, series_order):
"""Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
x_2 = math_ops.square(x)
# Log of the term multiplying (1 + sum)
- log_scale = -0.5 * x_2 - math_ops.log(-x) - 0.5 * math.log(2. * math.pi)
+ log_scale = -0.5 * x_2 - math_ops.log(-x) - 0.5 * np.log(2. * np.pi)
return log_scale + math_ops.log(_log_ndtr_asymptotic_series(x, series_order))
def _log_ndtr_asymptotic_series(x, series_order):
"""Calculates the asymptotic series used in log_ndtr."""
+ dtype = x.dtype.as_numpy_dtype
if series_order <= 0:
- return 1.
+ return np.array(1, dtype)
x_2 = math_ops.square(x)
- even_sum = 0.
- odd_sum = 0.
+ even_sum = array_ops.zeros_like(x)
+ odd_sum = array_ops.zeros_like(x)
x_2n = x_2 # Start with x^{2*1} = x^{2*n} with n = 1.
for n in range(1, series_order + 1):
+ y = np.array(_double_factorial(2 * n - 1), dtype) / x_2n
if n % 2:
- odd_sum += _double_factorial(2 * n - 1) / x_2n
+ odd_sum += y
else:
- even_sum += _double_factorial(2 * n - 1) / x_2n
+ even_sum += y
x_2n *= x_2
return 1. + even_sum - odd_sum
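In the `_create_polynomial` change above, the coefficients are cast to the input dtype and an empty coefficient list yields a dtype-matched zero tensor, keeping float32 inputs in float32. A small NumPy-only sketch of the same Horner recursion (made-up coefficients; TF ops replaced by their NumPy counterparts):

import numpy as np

def create_polynomial(var, coeffs):
  """Horner's method: c0 + var * (c1 + var * (c2 + ...))."""
  coeffs = np.array(coeffs, var.dtype)
  if not coeffs.size:
    return np.zeros_like(var)  # dtype-matched zero instead of Python 0.
  return coeffs[0] + create_polynomial(var, coeffs[1:]) * var

x = np.array([0.5, 2.0], np.float32)
print(create_polynomial(x, [1.0, -3.0, 2.0]))  # 1 - 3x + 2x^2 -> [0. 3.]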
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index 961b07a7bd..9330b930b5 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -157,7 +157,7 @@ class StudentT(distribution.Distribution):
Raises:
TypeError: if loc and scale are different dtypes.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[df, loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(df)]
if validate_args else []):
@@ -349,7 +349,7 @@ class StudentTWithAbsDfSoftplusScale(StudentT):
validate_args=False,
allow_nan_stats=True,
name="StudentTWithAbsDfSoftplusScale"):
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[df, scale]) as name:
super(StudentTWithAbsDfSoftplusScale, self).__init__(
df=math_ops.floor(math_ops.abs(df)),
diff --git a/tensorflow/python/ops/distributions/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py
index bc321900dc..9392464ec1 100644
--- a/tensorflow/python/ops/distributions/transformed_distribution.py
+++ b/tensorflow/python/ops/distributions/transformed_distribution.py
@@ -252,7 +252,7 @@ class TransformedDistribution(distribution_lib.Distribution):
name: Python `str` name prefixed to Ops created by this class. Default:
`bijector.name + distribution.name`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
name = name or (("" if bijector is None else bijector.name) +
distribution.name)
with ops.name_scope(name, values=[event_shape, batch_shape]) as name:
diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py
index 087797c653..dfa10331e3 100644
--- a/tensorflow/python/ops/distributions/uniform.py
+++ b/tensorflow/python/ops/distributions/uniform.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -102,7 +103,7 @@ class Uniform(distribution.Distribution):
Raises:
InvalidArgumentError: if `low >= high` and `validate_args=False`.
"""
- parameters = locals()
+ parameters = distribution_util.parent_frame_arguments()
with ops.name_scope(name, values=[low, high]) as name:
with ops.control_dependencies([
check_ops.assert_less(
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 3afa85fda0..59c89d21f9 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.util import tf_inspect
def assert_close(
@@ -1297,6 +1298,43 @@ def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
return x
+def parent_frame_arguments():
+ """Returns parent frame arguments.
+
+ When called inside a function, returns a dictionary with the caller's function
+ arguments. These are positional arguments and keyword arguments (**kwargs),
+ while variable arguments (*varargs) are excluded.
+
+ When called at global scope, this will return an empty dictionary, since there
+ are no arguments.
+
+ WARNING: If the caller's argument names are rebound to new values before
+ invoking this method, the returned dictionary will reflect the new values. For
+ this reason, we recommend calling `parent_frame_arguments` at the beginning of
+ the function.
+ """
+ # The caller's argument names, the *varargs/**kwargs names, and its locals.
+ arg_names, variable_arg_name, keyword_arg_name, local_vars = (
+ tf_inspect._inspect.getargvalues( # pylint: disable=protected-access
+ # Get the first frame of the caller of this method.
+ tf_inspect._inspect.stack()[1][0])) # pylint: disable=protected-access
+
+ # Drop the *varargs entry and pull out the **kwargs dict so its entries
+ # can be flattened into the returned arguments.
+ local_vars.pop(variable_arg_name, {})
+ keyword_args = local_vars.pop(keyword_arg_name, {})
+
+ final_args = {}
+ # Copy over arguments and their values. In general, local_vars
+ # may contain more than just the arguments, since this method
+ # can be called anywhere in a function.
+ for arg_name in arg_names:
+ final_args[arg_name] = local_vars.pop(arg_name)
+ final_args.update(keyword_args)
+
+ return final_args
+
+
class AppendDocstring(object):
"""Helper class to promote private subclass docstring to public counterpart.
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index c8a1500e76..fe463fa823 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -843,7 +843,9 @@ def _ForUsingWhile(start,
return (i + 1, n, start, delta) + tuple(for_result) + extra_args
if hostmem is not None:
- hostmem = [(4 + _) for _ in hostmem]
+ hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem]
+ else:
+ hostmem = [0, 1, 2, 3]
results = While(
input_=[0, n, start, delta] + inputs + WhileBody.captured_inputs,
diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py
index 9f43e3f146..102181e68b 100644
--- a/tensorflow/python/ops/image_grad.py
+++ b/tensorflow/python/ops/image_grad.py
@@ -107,16 +107,20 @@ def _CropAndResizeGrad(op, grad):
allowed_types = [dtypes.float16, dtypes.float32, dtypes.float64]
if op.inputs[0].dtype in allowed_types:
# pylint: disable=protected-access
- grad0 = gen_image_ops.crop_and_resize_grad_image(grad,
- op.inputs[1],
- op.inputs[2],
- image_shape,
- T=op.get_attr("T"))
+ grad0 = gen_image_ops.crop_and_resize_grad_image(
+ grad, op.inputs[1], op.inputs[2], image_shape, T=op.get_attr("T"),
+ method=op.get_attr("method"))
# pylint: enable=protected-access
else:
grad0 = None
- grad1 = gen_image_ops.crop_and_resize_grad_boxes(grad, op.inputs[0],
- op.inputs[1], op.inputs[2])
+ # `grad0` is the gradient to the input image pixels, and it has been
+ # implemented for both nearest neighbor and bilinear sampling. `grad1` is
+ # the gradient to the input crop boxes' coordinates. When using nearest
+ # neighbor sampling, the gradient to the crop boxes' coordinates is not
+ # well defined. In practice, we still approximate grad1 using the gradient
+ # derived from bilinear sampling.
+ grad1 = gen_image_ops.crop_and_resize_grad_boxes(
+ grad, op.inputs[0], op.inputs[1], op.inputs[2])
return [grad0, grad1, None, None]
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index bd5b2ae83b..54e27b87df 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -1772,6 +1772,7 @@ def non_max_suppression(boxes,
scores,
max_output_size,
iou_threshold=0.5,
+ score_threshold=0.0,
name=None):
"""Greedily selects a subset of bounding boxes in descending order of score.
@@ -1800,6 +1801,8 @@ def non_max_suppression(boxes,
of boxes to be selected by non max suppression.
iou_threshold: A float representing the threshold for deciding whether boxes
overlap too much with respect to IOU.
+ score_threshold: A float representing the threshold for deciding when to
+ remove boxes based on score.
name: A name for the operation (optional).
Returns:
@@ -1808,8 +1811,10 @@ def non_max_suppression(boxes,
"""
with ops.name_scope(name, 'non_max_suppression'):
iou_threshold = ops.convert_to_tensor(iou_threshold, name='iou_threshold')
- return gen_image_ops.non_max_suppression_v2(boxes, scores, max_output_size,
- iou_threshold)
+ score_threshold = ops.convert_to_tensor(
+ score_threshold, name='score_threshold')
+ return gen_image_ops.non_max_suppression_v3(boxes, scores, max_output_size,
+ iou_threshold, score_threshold)
_rgb_to_yiq_kernel = [[0.299, 0.59590059,
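A minimal usage sketch of the new argument (box and score values are invented; assumes the public tf.image.non_max_suppression wrapper defined above):

import tensorflow as tf

boxes = tf.constant([[0.0, 0.0, 1.0, 1.0],
                     [0.0, 0.1, 1.0, 1.1],
                     [0.0, 2.0, 1.0, 3.0]])
scores = tf.constant([0.9, 0.75, 0.05])

# Boxes scoring below score_threshold (here the third box) are discarded
# before the IoU-based suppression is applied.
selected = tf.image.non_max_suppression(
    boxes, scores, max_output_size=2, iou_threshold=0.5, score_threshold=0.1)

with tf.Session() as sess:
  print(sess.run(selected))  # [0]: box 1 is suppressed by IoU, box 2 by score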
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index f93bf0a17f..1f8d8dc4f3 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -488,9 +488,9 @@ class Orthogonal(Initializer):
If the shape of the tensor to initialize is two-dimensional, it is initialized
with an orthogonal matrix obtained from the QR decomposition of a matrix of
- uniform random numbers. If the matrix has fewer rows than columns then the
- output will have orthogonal rows. Otherwise, the output will have orthogonal
- columns.
+ random numbers drawn from a normal distribution.
+ If the matrix has fewer rows than columns then the output will have orthogonal
+ rows. Otherwise, the output will have orthogonal columns.
If the shape of the tensor to initialize is more than two-dimensional,
a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
diff --git a/tensorflow/python/ops/linalg/linear_operator_kronecker.py b/tensorflow/python/ops/linalg/linear_operator_kronecker.py
index da959f9a1c..1fd5073c17 100644
--- a/tensorflow/python/ops/linalg/linear_operator_kronecker.py
+++ b/tensorflow/python/ops/linalg/linear_operator_kronecker.py
@@ -381,10 +381,6 @@ class LinearOperatorKronecker(linear_operator.LinearOperator):
else:
matrix_dimensions = [self.range_dimension, column_dim]
- print("x: ", x)
- print("bathc_shape:", self.batch_shape)
- print("self.shape:", self.shape)
- print("output: ", output)
output.set_shape(broadcast_batch_shape.concatenate(
matrix_dimensions))
diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py
index 7e4fb6a6fc..1b5bb9470c 100644
--- a/tensorflow/python/ops/linalg/linear_operator_test_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py
@@ -178,7 +178,9 @@ class LinearOperatorDerivedClassTest(test.TestCase):
SkipTest Exception, if test_name is in self._tests_to_skip.
"""
if test_name in self._tests_to_skip:
- self.skipTest("%s skipped because it was added to self._tests_to_skip.")
+ self.skipTest(
+ "{} skipped because it was added to self._tests_to_skip.".format(
+ test_name))
def test_to_dense(self):
self._skip_if_tests_to_skip_contains("to_dense")
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index cd07550d2e..09a4425436 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -2100,11 +2100,10 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
Args:
value: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
`float32`, `float64`, `qint8`, `quint8`, or `qint32`.
- ksize: A 1-D int Tensor of 4 elements.
- The size of the window for each dimension of the input tensor.
- strides: A 1-D int Tensor of 4 elements
- The stride of the sliding window for each dimension of the
- input tensor.
+ ksize: A list or tuple of 4 ints. The size of the window for each dimension
+ of the input tensor.
+ strides: A list or tuple of 4 ints. The stride of the sliding window for
+ each dimension of the input tensor.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
See the @{tf.nn.convolution$comment here}
data_format: A string. 'NHWC' and 'NCHW' are supported.
@@ -2130,10 +2129,10 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
Args:
value: A 4-D `Tensor` of the format specified by `data_format`.
- ksize: A 1-D int Tensor of 4 elements. The size of the window for
+ ksize: A list or tuple of 4 ints. The size of the window for each dimension
+ of the input tensor.
+ strides: A list or tuple of 4 ints. The stride of the sliding window for
each dimension of the input tensor.
- strides: A 1-D int Tensor of 4 elements. The stride of the sliding
- window for each dimension of the input tensor.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
See the @{tf.nn.convolution$comment here}
data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index e94ad90dfd..c77a18d890 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -1401,6 +1401,13 @@ def static_state_saving_rnn(cell,
outputs[-1] = nest.pack_sequence_as(
structure=last_output, flat_sequence=flat_last_output)
+ if state_is_tuple:
+ state = nest.pack_sequence_as(
+ structure=state,
+ flat_sequence=[array_ops.identity(s) for s in flat_state])
+ else:
+ state = array_ops.identity(state)
+
return (outputs, state)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 67f753485b..68d22794d3 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -1005,6 +1005,8 @@ class DropoutWrapper(RNNCell):
# Set cell, variational_recurrent, seed before running the code below
self._cell = cell
+ if isinstance(cell, checkpointable.CheckpointableBase):
+ self._track_checkpointable(self._cell, name="cell")
self._variational_recurrent = variational_recurrent
self._seed = seed
@@ -1152,6 +1154,8 @@ class ResidualWrapper(RNNCell):
and outputs.
"""
self._cell = cell
+ if isinstance(cell, checkpointable.CheckpointableBase):
+ self._track_checkpointable(self._cell, name="cell")
self._residual_fn = residual_fn
@property
@@ -1207,6 +1211,8 @@ class DeviceWrapper(RNNCell):
device: A device string or function, for passing to `tf.device`.
"""
self._cell = cell
+ if isinstance(cell, checkpointable.CheckpointableBase):
+ self._track_checkpointable(self._cell, name="cell")
self._device = device
@property
@@ -1322,7 +1328,7 @@ class MultiRNNCell(RNNCell):
return cur_inp, new_states
-class _SlimRNNCell(RNNCell):
+class _SlimRNNCell(RNNCell, checkpointable.NotCheckpointable):
"""A simple wrapper for slim.rnn_cells."""
def __init__(self, cell_fn):
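A small sketch of what this tracking enables, assuming the wrapped cell is itself checkpointable (variable names are arbitrary; uses the internal modules touched in this diff):

from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.training import checkpointable_utils

cell = rnn_cell_impl.DropoutWrapper(
    rnn_cell_impl.LSTMCell(4), output_keep_prob=0.9)
# The wrapper now registers its inner cell as a checkpoint dependency named
# "cell", so an object-based checkpoint of the wrapper also covers the
# wrapped LSTMCell's variables once they are built.
ckpt = checkpointable_utils.Checkpoint(rnn_cell=cell)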
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 9f58c6a476..baf169b687 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -39,6 +39,8 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
+# Expose regex_full_match in strings namespace
+tf_export("strings.regex_full_match")(regex_full_match)
@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
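A quick sketch of the endpoint this exposes (graph mode; sample strings are arbitrary):

import tensorflow as tf

strings = tf.constant(["TensorFlow", "tensorflow"])
matches = tf.strings.regex_full_match(strings, r"Tensor.*")

with tf.Session() as sess:
  print(sess.run(matches))  # [ True False]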
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 9b6b8c508f..b46c46d871 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -295,42 +295,6 @@ class Template(checkpointable.CheckpointableBase):
# which is not the same as whether the scope has been created.
self._variables_created = False
- @property
- def _checkpoint_dependencies(self):
- """Sanity checking for object-based saving.
-
- Does not override Checkpointable dependency tracking, but checks that
- variables accessible through Checkpointable dependencies on other `Template`
- objects include all of the variable_scope-filtered `Template.variables`.
-
- Returns:
- A list of checkpointable.CheckpointableReference objects.
- Raises:
- ValueError: If this object is not compatible with object-based saving.
- """
- dependencies = super(Template, self)._checkpoint_dependencies
- dependency_variables = []
- for _, dependency in dependencies:
- if isinstance(dependency, Template):
- dependency_variables.extend(dependency.variables)
- else:
- dependency_variables.append(dependency)
- dependency_variables = set(dependency_variables)
- not_included_variables = []
- for expected_variable in sorted(self.variables, key=lambda v: v.name):
- if expected_variable not in dependency_variables:
- not_included_variables.append(expected_variable)
- if not_included_variables:
- # Trying to save a Template which improperly tracks its variables.
- raise ValueError(
- ("The Template '%s' references variables which are not included via "
- "object-based dependency tracking. Most likely a custom "
- "getter/creator was registered which does not call Template's "
- "custom variable creator (which is responsible for tracking "
- "dependencies).\n\nExpected these variables to be dependencies: %s")
- % (self, not_included_variables))
- return dependencies
-
def _checkpointable_custom_creator(self, next_creator, name, initial_value,
checkpointable_parent=None, **kwargs):
"""A variable creation hook which adds Checkpointable dependencies.
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index adb0f59948..d79d8c8bab 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -32,7 +32,6 @@ from six import iteritems
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -41,6 +40,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
@@ -422,7 +422,7 @@ class _VariableStore(object):
"use_resource": use_resource,
}
# `fn_args` can handle functions, `functools.partial`, `lambda`.
- if "constraint" in estimator_util.fn_args(custom_getter):
+ if "constraint" in function_utils.fn_args(custom_getter):
custom_getter_kwargs["constraint"] = constraint
return custom_getter(**custom_getter_kwargs)
else:
@@ -840,7 +840,8 @@ class _VariableStore(object):
initializing_from_value = False
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
# If dtype is DT_BOOL, provide a default value `FALSE`
- elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
+ elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool
+ or dtype == dtypes.string):
initializer = init_ops.zeros_initializer()
initializing_from_value = False
# NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here?
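The corresponding behavior for string variables, exercised by the test at the top of this section, looks roughly like this graph-mode sketch:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
  # With the change above, tf.string variables fall back to
  # zeros_initializer(), i.e. the empty byte string.
  v = tf.get_variable("a_string", shape=[], dtype=tf.string)
  sess.run(tf.global_variables_initializer())
  print(sess.run(v))  # b''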
diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py
index d00312a1f3..a57bcaea69 100644
--- a/tensorflow/python/training/checkpointable.py
+++ b/tensorflow/python/training/checkpointable.py
@@ -18,14 +18,21 @@ from __future__ import division
from __future__ import print_function
import collections
+import functools
+import json
+import weakref
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import saveable_object
from tensorflow.python.util import nest
+from tensorflow.python.util import serialization
# Key where the object graph proto is saved in a TensorBundle
@@ -37,6 +44,7 @@ OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
# the object has no dependencies, then its value may be restored on object
# creation (avoiding double assignment when executing eagerly).
VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
+OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"
CheckpointableReference = collections.namedtuple(
"CheckpointableReference",
@@ -85,6 +93,35 @@ class CheckpointInitialValue(ops.Tensor):
return self._checkpoint_position
+class PythonStringStateSaveable(saveable_object.SaveableObject):
+ """Saves Python state in a checkpoint."""
+
+ def __init__(self, name, state_callback):
+ """Configure saving.
+
+ Args:
+ name: The checkpoint key to write to.
+ state_callback: A function taking no arguments which returns a
+ string. This function is run every time a checkpoint is written.
+ """
+ if context.executing_eagerly():
+ self._save_string = (
+ lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
+ else:
+ self._save_string = constant_op.constant("", dtype=dtypes.string)
+ self.feed_dict_additions = (
+ lambda: {self._save_string: state_callback()})
+ spec = saveable_object.SaveSpec(
+ self._save_string, "", name, dtype=dtypes.string)
+ super(PythonStringStateSaveable, self).__init__(
+ self._save_string, [spec], name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ # TODO(allenl): Add a Python hook for state coming out of a checkpoint
+ # (currently PythonStringStateSaveable is write-only).
+ return control_flow_ops.no_op()
+
+
class _CheckpointPosition(object):
"""Indicates a position within a `_Checkpoint`."""
@@ -604,7 +641,6 @@ class CheckpointableBase(object):
# restoration on to our dependencies.
if checkpoint.restore_uid > self._update_uid:
restore_ops = checkpoint_position.restore_ops()
- # TODO(allenl): Get a list of feeds for saving Python state
self._update_uid = checkpoint.restore_uid
else:
restore_ops = ()
@@ -656,7 +692,24 @@ class CheckpointableBase(object):
lambda name="global_name_for_this_object":
SaveableObject(name=name, ...)}
"""
- return {}
+ if not hasattr(self, "get_config"):
+ return {}
+ try:
+ self.get_config()
+ except NotImplementedError:
+ return {}
+ weak_self = weakref.ref(self)
+ def _state_callback():
+ dereferenced_self = weak_self()
+ if dereferenced_self:
+ return json.dumps(dereferenced_self,
+ default=serialization.get_json_type,
+ sort_keys=True).encode("utf8")
+ else:
+ return ""
+ return {OBJECT_CONFIG_JSON_KEY: functools.partial(
+ PythonStringStateSaveable,
+ state_callback=_state_callback)}
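A rough sketch of what the new default _gather_saveables_for_checkpoint does for objects exposing get_config() (uses the private module path from this diff; the MyConfigured class is hypothetical):

from tensorflow.python.training import checkpointable

class MyConfigured(checkpointable.CheckpointableBase):

  def get_config(self):
    return {"units": 8}

obj = MyConfigured()
saveables = obj._gather_saveables_for_checkpoint()
# Objects with a working get_config() now advertise an extra attribute whose
# factory builds a PythonStringStateSaveable holding the JSON-encoded config.
assert checkpointable.OBJECT_CONFIG_JSON_KEY in saveables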
class NoDependency(object):
@@ -684,6 +737,17 @@ class NoDependency(object):
self.value = value
+class NotCheckpointable(object):
+ """Marks instances of child classes as unsaveable using an object-based API.
+
+ Useful for marking objects which would otherwise look checkpointable because
+ of inheritance (e.g. through `Layer`) as not checkpointable. Inheriting from
+ `NotCheckpointable` does not prevent an object from being assigned to any
+ attributes, but will throw an error on save/restore.
+ """
+ pass
+
+
class Checkpointable(CheckpointableBase):
"""Manages dependencies on other objects.
diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py
index f2a2b411fd..72be434fb2 100644
--- a/tensorflow/python/training/checkpointable_utils.py
+++ b/tensorflow/python/training/checkpointable_utils.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import checkpointable as checkpointable_lib
from tensorflow.python.training import optimizer as optimizer_lib
+from tensorflow.python.training import saveable_object
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -204,6 +205,12 @@ def _breadth_first_checkpointable_traversal(root_checkpointable):
path_to_root = {root_checkpointable: ()}
while to_visit:
current_checkpointable = to_visit.popleft()
+ if isinstance(current_checkpointable, checkpointable_lib.NotCheckpointable):
+ raise NotImplementedError(
+ ("The object %s does not support object-based saving. File a feature "
+ "request if this limitation bothers you. In the meantime, you can "
+ "remove the dependency on this object and save everything else.")
+ % (current_checkpointable,))
current_checkpointable._maybe_initialize_checkpointable() # pylint: disable=protected-access
bfs_sorted.append(current_checkpointable)
for child_checkpointable in (
@@ -303,42 +310,93 @@ def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
def _serialize_checkpointables(
- checkpointable_objects, node_ids, object_names, slot_variables):
+ checkpointable_objects, node_ids, object_names, slot_variables,
+ saveables_cache):
"""Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
- named_saveables = {}
-
+ named_saveables = []
+ feed_additions = {}
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
assert node_ids[checkpointable] == checkpoint_id
object_proto = object_graph_proto.nodes.add()
object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
object_name = object_names[checkpointable]
+ if saveables_cache is not None:
+ cached_attributes = saveables_cache.setdefault(checkpointable, {})
+ else:
+ cached_attributes = None
for name, saveable_factory in (
checkpointable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access
attribute = object_proto.attributes.add()
attribute.name = name
attribute.checkpoint_key = "%s/%s/%s" % (
object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
- if callable(saveable_factory):
- saveable = saveable_factory(name=attribute.checkpoint_key)
+ if cached_attributes is None:
+ saveables = None
else:
- saveable = saveable_factory
- # Figure out the name-based Saver's name for this variable.
- saver_dict = saver_lib.BaseSaverBuilder.OpListToDict(
- [saveable], convert_variable_to_tensor=False)
- attribute.full_name, = saver_dict.keys()
- named_saveables[attribute.checkpoint_key] = saveable
+ saveables = cached_attributes.get(name, None)
+ if saveables is not None:
+ for saveable in saveables:
+ if attribute.checkpoint_key not in saveable.name:
+ # The checkpoint key for this SaveableObject is different. We need
+ # to re-create it.
+ saveables = None
+ del cached_attributes[name]
+ break
+ if saveables is None:
+ if callable(saveable_factory):
+ maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
+ else:
+ maybe_saveable = saveable_factory
+ if isinstance(maybe_saveable, saveable_object.SaveableObject):
+ saveables = (maybe_saveable,)
+ else:
+ # Figure out the name-based Saver's name for this variable. If it's
+ # already a SaveableObject we'd just get the checkpoint key back, so
+ # we leave full_name blank.
+ saver_dict = saver_lib.BaseSaverBuilder.OpListToDict(
+ [maybe_saveable], convert_variable_to_tensor=False)
+ full_name, = saver_dict.keys()
+ saveables = tuple(saver_lib.BaseSaverBuilder.SaveableObjectsForOp(
+ op=maybe_saveable, name=attribute.checkpoint_key))
+ for saveable in saveables:
+ saveable.full_name = full_name
+ for saveable in saveables:
+ if attribute.checkpoint_key not in saveable.name:
+ raise AssertionError(
+ ("The object %s produced a SaveableObject with name '%s' for "
+ "attribute '%s'. Expected a name containing '%s'.")
+ % (checkpointable, saveable.name, name,
+ attribute.checkpoint_key))
+ if cached_attributes is not None:
+ cached_attributes[name] = saveables
+
+ for saveable in saveables:
+ if hasattr(saveable, "full_name"):
+ attribute.full_name = saveable.full_name
+ saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
+ if saveable_feed_dict_fn is not None:
+ saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable
+ for new_feed_key in saveable_feed_dict.keys():
+ if new_feed_key in feed_additions:
+ raise AssertionError(
+ ("The object %s tried to feed a value for the Tensor %s "
+ "when saving, but another object is already feeding a "
+ "value.")
+ % (checkpointable, new_feed_key))
+ feed_additions.update(saveable_feed_dict)
+ named_saveables.extend(saveables)
for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
child_proto = object_proto.children.add()
child_proto.node_id = node_ids[child.ref]
child_proto.local_name = child.name
- return named_saveables, object_graph_proto
+ return named_saveables, object_graph_proto, feed_additions
-def _serialize_object_graph(root_checkpointable):
+def _serialize_object_graph(root_checkpointable, saveables_cache):
"""Determine checkpoint keys for variables and build a serialized graph.
Non-slot variables are keyed based on a shortest path from the root saveable
@@ -351,12 +409,17 @@ def _serialize_object_graph(root_checkpointable):
Args:
root_checkpointable: A `Checkpointable` object whose variables (including
the variables of dependencies, recursively) should be saved.
+ saveables_cache: A dictionary mapping `Checkpointable` objects -> attribute
+ names -> SaveableObjects, used to avoid re-creating SaveableObjects when
+ graph building.
Returns:
- A tuple of (named_variables, object_graph_proto):
+ A tuple of (named_variables, object_graph_proto, feed_additions):
named_variables: A dictionary mapping names to variable objects.
object_graph_proto: A CheckpointableObjectGraph protocol buffer containing
the serialized object graph and variable references.
+ feed_additions: A dictionary mapping from Tensors to values which should
+ be fed when saving.
Raises:
ValueError: If there are invalid characters in an optimizer's slot names.
@@ -376,7 +439,8 @@ def _serialize_object_graph(root_checkpointable):
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
object_names=object_names,
- slot_variables=slot_variables)
+ slot_variables=slot_variables,
+ saveables_cache=saveables_cache)
def list_objects(root_checkpointable):
@@ -728,6 +792,14 @@ class CheckpointableSaver(object):
self._last_restore_object_graph = None
self._last_restore_checkpoint = None
+ if context.executing_eagerly():
+ # SaveableObjects are always recreated when executing eagerly.
+ self._saveable_object_cache = None
+ else:
+ # Maps Checkpointable objects -> attribute names -> SaveableObjects, to
+ # avoid re-creating SaveableObjects when graph building.
+ self._saveable_object_cache = weakref.WeakKeyDictionary()
+
@property
def _root_checkpointable(self):
if isinstance(self._root_checkpointable_ref, weakref.ref):
@@ -759,8 +831,9 @@ class CheckpointableSaver(object):
Returns:
The full path to the checkpoint.
"""
- named_variables, graph_proto = _serialize_object_graph(
- self._root_checkpointable)
+ named_variables, graph_proto, feed_additions = _serialize_object_graph(
+ self._root_checkpointable,
+ saveables_cache=self._saveable_object_cache)
if not context.executing_eagerly():
if session is None:
session = ops.get_default_session()
@@ -769,15 +842,15 @@ class CheckpointableSaver(object):
self._object_graph_feed_tensor = constant_op.constant(
"", dtype=dtypes.string)
object_graph_tensor = self._object_graph_feed_tensor
- feed_additions = {object_graph_tensor: graph_proto.SerializeToString()}
+ feed_additions.update(
+ {object_graph_tensor: graph_proto.SerializeToString()})
else:
session = None
with ops.device("/cpu:0"):
object_graph_tensor = constant_op.constant(
graph_proto.SerializeToString(), dtype=dtypes.string)
- feed_additions = None
assert checkpointable_lib.OBJECT_GRAPH_PROTO_KEY not in named_variables
- named_variables[checkpointable_lib.OBJECT_GRAPH_PROTO_KEY] = (
+ named_variables.append(
_NoRestoreSaveable(
tensor=object_graph_tensor,
name=checkpointable_lib.OBJECT_GRAPH_PROTO_KEY))
@@ -804,13 +877,23 @@ class CheckpointableSaver(object):
def _global_variable_names(self):
"""Generate a `tf.train.Saver`-style `var_list` using `variable.name`s."""
- named_saveables, graph_proto = _serialize_object_graph(
- self._root_checkpointable)
+ named_saveables, graph_proto, _ = _serialize_object_graph(
+ self._root_checkpointable,
+ # We destructively modify SaveableObjects, so don't do any caching.
+ saveables_cache=None)
+ named_saveables = {v.name: v for v in named_saveables}
saver_names = {}
for object_proto in graph_proto.nodes:
for attribute_proto in object_proto.attributes:
- saver_names[attribute_proto.full_name] = named_saveables[
- attribute_proto.checkpoint_key]
+ if attribute_proto.full_name:
+ # Ignore attributes, such as Python object JSON, which don't have a
+ # name-based Saver name.
+ saveable = named_saveables[attribute_proto.checkpoint_key]
+ saveable.name = attribute_proto.full_name
+ for spec in saveable.specs:
+ spec.name = spec.name.replace(attribute_proto.checkpoint_key,
+ attribute_proto.full_name)
+ saver_names[attribute_proto.full_name] = saveable
return saver_names
def restore(self, save_path):
@@ -1037,6 +1120,7 @@ class Checkpoint(checkpointable_lib.Checkpointable):
% (v,))
setattr(self, k, v)
self._save_counter = None # Created lazily for restore-on-create.
+ self._save_assign_op = None
self._saver = CheckpointableSaver(weakref.ref(self))
def _maybe_create_save_counter(self):
@@ -1089,10 +1173,13 @@ class Checkpoint(checkpointable_lib.Checkpointable):
# needs to be initialized before assign_add. This is only an issue if
# restore() has not been called first.
session.run(self.save_counter.initializer)
- with ops.colocate_with(self.save_counter):
- assign_op = self.save_counter.assign_add(1)
+ if not in_graph_mode or self._save_assign_op is None:
+ with ops.colocate_with(self.save_counter):
+ assign_op = self.save_counter.assign_add(1, read_value=False)
+ if in_graph_mode:
+ self._save_assign_op = assign_op
if in_graph_mode:
- session.run(assign_op)
+ session.run(self._save_assign_op)
return self._saver.save(
file_prefix=file_prefix,
checkpoint_number=self.save_counter,
diff --git a/tensorflow/python/training/checkpointable_utils_test.py b/tensorflow/python/training/checkpointable_utils_test.py
index 3b8166bf37..d94cdcfc06 100644
--- a/tensorflow/python/training/checkpointable_utils_test.py
+++ b/tensorflow/python/training/checkpointable_utils_test.py
@@ -17,10 +17,12 @@ from __future__ import division
from __future__ import print_function
import functools
+import json
import os
import six
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -120,7 +122,8 @@ class InterfaceTests(test.TestCase):
# The .name attribute may be globally influenced, but the checkpoint name
# won't be (tested below).
self.assertEqual("duplicate_1:0", duplicate.name)
- named_variables, _ = checkpointable_utils._serialize_object_graph(obj)
+ named_variables, _, _ = checkpointable_utils._serialize_object_graph(
+ obj, saveables_cache=None)
expected_checkpoint_names = (
"a_variable/.ATTRIBUTES/VARIABLE_VALUE",
"bare_initializer/.ATTRIBUTES/VARIABLE_VALUE",
@@ -129,7 +132,7 @@ class InterfaceTests(test.TestCase):
"ones_initializer/.ATTRIBUTES/VARIABLE_VALUE",
)
six.assertCountEqual(
- self, expected_checkpoint_names, named_variables.keys())
+ self, expected_checkpoint_names, [v.name for v in named_variables])
def testInitNotCalled(self):
@@ -171,6 +174,27 @@ class InterfaceTests(test.TestCase):
all_variable_names.append(attribute.full_name)
self.assertIn("dense/kernel", all_variable_names)
+ def testNotCheckpointable(self):
+
+ class CallsFunctionalStuff(
+ checkpointable.NotCheckpointable, checkpointable.Checkpointable):
+ pass
+
+ test_dir = self.get_temp_dir()
+ prefix = os.path.join(test_dir, "ckpt")
+ checkpoint = checkpointable_utils.Checkpoint(x=CallsFunctionalStuff())
+ with self.assertRaises(NotImplementedError):
+ checkpoint.save(prefix)
+
+ class CallsFunctionalStuffOtherMRO(
+ checkpointable.Checkpointable, checkpointable.NotCheckpointable):
+ pass
+
+ checkpoint_reversed = checkpointable_utils.Checkpoint(
+ x=CallsFunctionalStuffOtherMRO())
+ with self.assertRaises(NotImplementedError):
+ checkpoint_reversed.save(prefix)
+
class _MirroringSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
@@ -245,8 +269,9 @@ class CheckpointingTests(test.TestCase):
self.evaluate(checkpointable_utils.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
- named_variables, serialized_graph = (
- checkpointable_utils._serialize_object_graph(root_checkpointable))
+ named_variables, serialized_graph, _ = (
+ checkpointable_utils._serialize_object_graph(
+ root_checkpointable, saveables_cache=None))
expected_checkpoint_names = (
# Created in the root node, so no prefix.
"optimizer_step",
@@ -269,24 +294,29 @@ class CheckpointingTests(test.TestCase):
suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
expected_checkpoint_names = [
name + suffix for name in expected_checkpoint_names]
+ # The Dense layers also save get_config() JSON
+ expected_checkpoint_names.extend(
+ ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+ "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+ named_variables = {v.name: v for v in named_variables}
six.assertCountEqual(self, expected_checkpoint_names,
named_variables.keys())
# Check that we've mapped to the right variable objects (not exhaustive)
self.assertEqual(
- "global_step:0",
- named_variables["optimizer_step" + suffix].name)
+ "global_step",
+ named_variables["optimizer_step" + suffix].full_name)
self.assertEqual(
- "my_model/dense_1/kernel:0",
- named_variables["model/_second/kernel" + suffix].name)
+ "my_model/dense_1/kernel",
+ named_variables["model/_second/kernel" + suffix].full_name)
self.assertEqual(
- "my_model/dense/kernel:0",
- named_variables["model/_named_dense/kernel" + suffix].name)
+ "my_model/dense/kernel",
+ named_variables["model/_named_dense/kernel" + suffix].full_name)
self.assertEqual(
- "beta1_power:0",
- named_variables["optimizer/beta1_power" + suffix].name)
+ "beta1_power",
+ named_variables["optimizer/beta1_power" + suffix].full_name)
self.assertEqual(
- "beta2_power:0",
- named_variables["optimizer/beta2_power" + suffix].name)
+ "beta2_power",
+ named_variables["optimizer/beta2_power" + suffix].full_name)
# Spot check the generated protocol buffers.
self.assertEqual("optimizer",
serialized_graph.nodes[0].children[1].local_name)
@@ -311,7 +341,7 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(
"my_model/dense/kernel/Adam:0",
optimizer.get_slot(
- var=named_variables["model/_named_dense/kernel" + suffix],
+ var=model._named_dense.kernel,
name="m").name)
self.assertEqual(
"model/_named_dense/kernel" + suffix,
@@ -563,11 +593,11 @@ class CheckpointingTests(test.TestCase):
root = checkpointable.Checkpointable()
checkpointable_utils.add_variable(
root, name=name, shape=[1, 2], dtype=dtypes.float64)
- named_variables, _ = checkpointable_utils._serialize_object_graph(root)
- checkpoint_name, = named_variables.keys()
- with ops.name_scope("root/" + checkpoint_name):
+ (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
+ root, saveables_cache=None)
+ with ops.name_scope("root/" + named_variable.name):
pass # Make sure we can use this as an op name if we prefix it.
- return checkpoint_name
+ return named_variable.name
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testVariableNameEscaping(self):
@@ -585,9 +615,9 @@ class CheckpointingTests(test.TestCase):
leaf = checkpointable.Checkpointable()
root.leaf = leaf
checkpointable_utils.add_variable(leaf, name="v", shape=[])
- named_variables, _ = checkpointable_utils._serialize_object_graph(root)
- variable_name, = named_variables.keys()
- self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", variable_name)
+ (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
+ root, saveables_cache=None)
+ self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", named_variable.name)
@test_util.run_in_graph_and_eager_modes()
def testLocalNameValidation(self):
@@ -596,9 +626,10 @@ class CheckpointingTests(test.TestCase):
# Dots are escaped, which avoids conflicts with reserved names.
root._track_checkpointable(leaf, name=".ATTRIBUTES")
checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[])
- named_variables, _ = checkpointable_utils._serialize_object_graph(root)
- name, = named_variables.keys()
- self.assertEqual(name, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE")
+ (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
+ root, saveables_cache=None)
+ self.assertEqual("..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE",
+ named_variable.name)
def testAnonymousVarsInInit(self):
@@ -1219,14 +1250,20 @@ class TemplateTests(test.TestCase):
def _templated():
v = variable_scope.get_variable(
- "v", shape=[1], initializer=init_ops.zeros_initializer())
+ "v", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
v2 = variable_scope.get_variable(
- "v2", shape=[1], initializer=init_ops.zeros_initializer())
+ "v2", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
return v, v + 1., v2
save_template = template.make_template("s1", _templated)
- save_root = checkpointable_utils.Checkpoint(my_template=save_template)
v1_save, _, v2_save = save_template()
+ optimizer = adam.AdamOptimizer(0.0)
+ save_root = checkpointable_utils.Checkpoint(
+ my_template=save_template, optimizer=optimizer)
+ optimizer.minimize(v1_save.read_value)
+ self.evaluate([v.initializer for v in optimizer.variables()])
self.evaluate(v1_save.assign([12.]))
self.evaluate(v2_save.assign([14.]))
checkpoint_directory = self.get_temp_dir()
@@ -1234,9 +1271,12 @@ class TemplateTests(test.TestCase):
save_path = save_root.save(checkpoint_prefix)
load_template = template.make_template("s2", _templated)
- load_root = checkpointable_utils.Checkpoint(my_template=load_template)
+ load_optimizer = adam.AdamOptimizer(0.0)
+ load_root = checkpointable_utils.Checkpoint(
+ my_template=load_template, optimizer=load_optimizer)
status = load_root.restore(save_path)
var, var_plus_one, var2 = load_template()
+ load_optimizer.minimize(var.read_value)
self.assertEqual(2, len(load_template._checkpoint_dependencies))
self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
@@ -1395,5 +1435,48 @@ class CheckpointCompatibilityTests(test.TestCase):
root.restore(save_path).assert_consumed().run_restore_ops()
self._check_sentinels(root)
+
+class PythonMetadataTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testSaveLoad(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dense = core.Dense(1)
+ checkpoint = checkpointable_utils.Checkpoint(dense=dense)
+ dense(constant_op.constant([[1.]]))
+ checkpoint.restore(None).initialize_or_restore()
+ save_path = checkpoint.save(checkpoint_prefix)
+
+ def _get_dense_node_from_object_graph(object_graph_proto):
+ root_node = object_graph_proto.nodes[0]
+ for child in root_node.children:
+ if child.local_name == "dense":
+ break
+ else:
+ raise AssertionError(
+ "Expected a 'dense' dependency of root, didn't find one.")
+ dense_node = object_graph_proto.nodes[child.node_id] # pylint: disable=undefined-loop-variable
+ self.assertEqual(1, len(dense_node.attributes))
+ reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+ layer_json = reader.get_tensor(dense_node.attributes[0].checkpoint_key)
+ return json.loads(layer_json.decode("utf-8"))
+
+ layer_data = _get_dense_node_from_object_graph(
+ checkpointable_utils.object_metadata(save_path))
+ self.assertEqual("Dense", layer_data["class_name"])
+ self.assertEqual(1, layer_data["config"]["units"])
+
+ # Check that no new ops are added to the graph the second time we save.
+ ops.get_default_graph().finalize()
+
+ dense.units = 42
+ save_path = checkpoint.save(checkpoint_prefix)
+ layer_data = _get_dense_node_from_object_graph(
+ checkpointable_utils.object_metadata(save_path))
+ self.assertEqual("Dense", layer_data["class_name"])
+ self.assertEqual(42, layer_data["config"]["units"])
+
+
if __name__ == "__main__":
test.main()
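A minimal sketch of what PythonMetadataTests exercises above, assuming the same imports as this test file (json, pywrap_tensorflow, checkpointable_utils) and a `save_path` produced by Checkpoint.save: the Dense layer's get_config() JSON can be read straight back out of the checkpoint.

    metadata = checkpointable_utils.object_metadata(save_path)
    dense_node = next(metadata.nodes[child.node_id]
                      for child in metadata.nodes[0].children
                      if child.local_name == "dense")
    reader = pywrap_tensorflow.NewCheckpointReader(save_path)
    layer_json = reader.get_tensor(dense_node.attributes[0].checkpoint_key)
    layer_config = json.loads(layer_json.decode("utf-8"))
    # layer_config["class_name"] == "Dense", layer_config["config"]["units"] == 1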
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index f584a009d9..fece3370f3 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -25,7 +25,6 @@ import sys
import six
from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.estimator import util
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -41,6 +40,7 @@ from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
+from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export
@@ -620,7 +620,7 @@ class _MonitoredSession(object):
`step_context`. It may also optionally have `self` for cases when it
belongs to an object.
"""
- step_fn_arguments = util.fn_args(step_fn)
+ step_fn_arguments = function_utils.fn_args(step_fn)
if step_fn_arguments != ('step_context',) and step_fn_arguments != (
'self',
'step_context',
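For context on the fn_args check above: run_step_fn only accepts callables whose argument list is exactly `step_context` (plus `self` when the callable is a method). A minimal conforming step function, as a sketch where `train_op` is a hypothetical training op defined elsewhere and monitored_session is imported from tensorflow.python.training:

    def step_fn(step_context):
      # The step context exposes the underlying session via run_with_hooks.
      return step_context.run_with_hooks(train_op)

    with monitored_session.MonitoredSession() as session:
      session.run_step_fn(step_fn)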
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 66914bacf3..a676ef9a12 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -1175,7 +1175,16 @@ class Optimizer(
variable_key = _var_key(variable)
slot_variable = named_slots.get(variable_key, None)
if (slot_variable is None and context.executing_eagerly() and
- slot_variable_position.is_simple_variable()):
+ slot_variable_position.is_simple_variable()
+ # Defer slot variable creation if there is an active variable creator
+ # scope. Generally we'd like to eagerly create/restore slot variables
+ # when possible, but this may mean that scopes intended to catch
+ # `variable` also catch its eagerly created slot variable
+ # unintentionally (specifically make_template would add a dependency on
+ # a slot variable if not for this case). Deferring is mostly harmless
+ # (aside from double initialization), and makes variable creator scopes
+ # behave the same way they do when graph building.
+ and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
initializer = checkpointable.CheckpointInitialValue(
checkpoint_position=slot_variable_position)
slot_variable = self._get_or_make_slot(
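Restated as a standalone predicate, the condition added above (illustrative only; `_variable_creator_stack` is a private graph attribute and all names come from this diff):

    def _restore_slot_eagerly_now(slot_variable, slot_variable_position):
      # Defer creation whenever a variable creator scope is active, so the
      # scope does not unintentionally capture the eagerly created slot.
      return (slot_variable is None
              and context.executing_eagerly()
              and slot_variable_position.is_simple_variable()
              and not ops.get_default_graph()._variable_creator_stack)  # pylint: disable=protected-access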
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 8134fd74aa..d11502ff15 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -569,6 +569,76 @@ class BaseSaverBuilder(object):
# pylint: enable=protected-access
return names_to_saveables
+ @staticmethod
+ def SaveableObjectsForOp(op, name):
+ """Create `SaveableObject`s from an operation.
+
+ Args:
+ op: A variable, operation, or SaveableObject to coerce into a
+ SaveableObject.
+ name: A string name for the SaveableObject.
+
+ Yields:
+ `SaveableObject`s which together save/restore `op`.
+
+ Raises:
+ TypeError: If `name` is not a string.
+ ValueError: For operations with no known conversion to SaveableObject.
+ """
+ if not isinstance(name, six.string_types):
+ raise TypeError(
+ "names_to_saveables must be a dict mapping string names to "
+ "checkpointable operations. Name is not a string: %s" % name)
+ if isinstance(op, BaseSaverBuilder.SaveableObject):
+ yield op
+ elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
+ if isinstance(op, variables.PartitionedVariable):
+ op = list(op)
+ # A set of slices.
+ slice_name = None
+ # pylint: disable=protected-access
+ for variable in op:
+ if not isinstance(variable, variables.Variable):
+ raise ValueError("Slices must all be Variables: %s" % variable)
+ if not variable._save_slice_info:
+ raise ValueError("Slices must all be slices: %s" % variable)
+ if slice_name is None:
+ slice_name = variable._save_slice_info.full_name
+ elif slice_name != variable._save_slice_info.full_name:
+ raise ValueError(
+ "Slices must all be from the same tensor: %s != %s" %
+ (slice_name, variable._save_slice_info.full_name))
+ if variable.op.type in ["Variable", "VariableV2",
+ "AutoReloadVariable"]:
+ yield BaseSaverBuilder.VariableSaveable(
+ variable, variable._save_slice_info.spec, name)
+ else:
+ yield BaseSaverBuilder.ResourceVariableSaveable(
+ variable, variable._save_slice_info.spec, name)
+ # pylint: enable=protected-access
+ else:
+ # A variable or tensor.
+ if context.executing_eagerly():
+ if not isinstance(op, resource_variable_ops.ResourceVariable):
+ raise ValueError("Can only save/restore ResourceVariable eager "
+ "mode is enabled, type: %s." % type(op))
+ yield BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
+ else:
+ if isinstance(op, resource_variable_ops.ResourceVariable):
+ variable = op._graph_element # pylint: disable=protected-access
+ else:
+ variable = ops.internal_convert_to_tensor(op, as_ref=True)
+ if not BaseSaverBuilder._IsVariable(variable):
+ raise TypeError("names_to_saveables must be a dict mapping string "
+ "names to Tensors/Variables. Not a variable: %s" %
+ variable)
+ if variable.op.type in ["Variable", "VariableV2",
+ "AutoReloadVariable"]:
+ yield BaseSaverBuilder.VariableSaveable(variable, "", name)
+ else:
+ yield BaseSaverBuilder.ResourceVariableSaveable(
+ variable, "", name)
+
def _ValidateAndSliceInputs(self, names_to_saveables):
"""Returns the variables and names that will be used for a Saver.
@@ -590,63 +660,11 @@ class BaseSaverBuilder(object):
saveables = []
seen_ops = set()
- for name in sorted(names_to_saveables.keys()):
- if not isinstance(name, six.string_types):
- raise TypeError(
- "names_to_saveables must be a dict mapping string names to "
- "checkpointable operations. Name is not a string: %s" % name)
- op = names_to_saveables[name]
- if isinstance(op, BaseSaverBuilder.SaveableObject):
- self._AddSaveable(saveables, seen_ops, op)
- elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
- if isinstance(op, variables.PartitionedVariable):
- op = list(op)
- # A set of slices.
- slice_name = None
- # pylint: disable=protected-access
- for variable in op:
- if not isinstance(variable, variables.Variable):
- raise ValueError("Slices must all be Variables: %s" % variable)
- if not variable._save_slice_info:
- raise ValueError("Slices must all be slices: %s" % variable)
- if slice_name is None:
- slice_name = variable._save_slice_info.full_name
- elif slice_name != variable._save_slice_info.full_name:
- raise ValueError(
- "Slices must all be from the same tensor: %s != %s" %
- (slice_name, variable._save_slice_info.full_name))
- if variable.op.type in ["Variable", "VariableV2",
- "AutoReloadVariable"]:
- saveable = BaseSaverBuilder.VariableSaveable(
- variable, variable._save_slice_info.spec, name)
- else:
- saveable = BaseSaverBuilder.ResourceVariableSaveable(
- variable, variable._save_slice_info.spec, name)
- self._AddSaveable(saveables, seen_ops, saveable)
- # pylint: enable=protected-access
- else:
- # A variable or tensor.
- if context.executing_eagerly():
- if not isinstance(op, resource_variable_ops.ResourceVariable):
- raise ValueError("Can only save/restore ResourceVariable eager "
- "mode is enabled, type: %s." % type(op))
- saveable = BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
- else:
- if isinstance(op, resource_variable_ops.ResourceVariable):
- variable = op._graph_element # pylint: disable=protected-access
- else:
- variable = ops.internal_convert_to_tensor(op, as_ref=True)
- if not BaseSaverBuilder._IsVariable(variable):
- raise TypeError("names_to_saveables must be a dict mapping string "
- "names to Tensors/Variables. Not a variable: %s" %
- variable)
- if variable.op.type in ["Variable", "VariableV2",
- "AutoReloadVariable"]:
- saveable = BaseSaverBuilder.VariableSaveable(variable, "", name)
- else:
- saveable = BaseSaverBuilder.ResourceVariableSaveable(
- variable, "", name)
- self._AddSaveable(saveables, seen_ops, saveable)
+ for name, op in sorted(names_to_saveables.items(),
+ # Avoid comparing ops, sort only by name.
+ key=lambda x: x[0]):
+ for converted_saveable_object in self.SaveableObjectsForOp(op, name):
+ self._AddSaveable(saveables, seen_ops, converted_saveable_object)
return saveables
def _AddSaveable(self, saveables, seen_ops, saveable):
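_ValidateAndSliceInputs now delegates the per-op conversion to the new SaveableObjectsForOp generator, which can also be used directly. A minimal sketch, assuming this module's existing imports (resource_variable_ops, and BaseSaverBuilder defined above):

    v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="v")
    saveables = list(BaseSaverBuilder.SaveableObjectsForOp(v, "v"))
    # One ResourceVariableSaveable covering the whole variable; a
    # PartitionedVariable would instead yield one saveable per slice.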
diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py
new file mode 100644
index 0000000000..7bbbde3cd2
--- /dev/null
+++ b/tensorflow/python/util/function_utils.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility to retrieve function args."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
+
+def _is_bounded_method(fn):
+ _, fn = tf_decorator.unwrap(fn)
+ return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
+
+
+def _is_callable_object(obj):
+ return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__)
+
+
+def fn_args(fn):
+ """Get argument names for function-like object.
+
+ Args:
+ fn: Function, or function-like object (e.g., result of `functools.partial`).
+
+ Returns:
+ `tuple` of string argument names.
+
+ Raises:
+ ValueError: if the partial function has positionally bound arguments.
+ """
+ if isinstance(fn, functools.partial):
+ args = fn_args(fn.func)
+ args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
+ else:
+ if _is_callable_object(fn):
+ fn = fn.__call__
+ args = tf_inspect.getfullargspec(fn).args
+ if _is_bounded_method(fn):
+ args.remove('self')
+ return tuple(args)
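fn_args is a drop-in replacement for the estimator util.fn_args that monitored_session switches to above; it resolves plain functions, bound methods, callable objects, and functools.partial wrappers to their remaining argument names. A small usage sketch:

    import functools
    from tensorflow.python.util import function_utils

    def fn(a, b, c=3):
      return a + b + c

    function_utils.fn_args(fn)                          # ('a', 'b', 'c')
    function_utils.fn_args(functools.partial(fn, b=2))  # ('a', 'c'); 'b' is bound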
diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/util/function_utils_test.py
index 4b2c8d7637..e78cf6a5b0 100644
--- a/tensorflow/python/estimator/util_test.py
+++ b/tensorflow/python/util/function_utils_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import functools
-from tensorflow.python.estimator import util
from tensorflow.python.platform import test
+from tensorflow.python.util import function_utils
class FnArgsTest(test.TestCase):
@@ -29,7 +29,7 @@ class FnArgsTest(test.TestCase):
def test_simple_function(self):
def fn(a, b):
return a + b
- self.assertEqual(('a', 'b'), util.fn_args(fn))
+ self.assertEqual(('a', 'b'), function_utils.fn_args(fn))
def test_callable(self):
@@ -38,7 +38,7 @@ class FnArgsTest(test.TestCase):
def __call__(self, a, b):
return a + b
- self.assertEqual(('a', 'b'), util.fn_args(Foo()))
+ self.assertEqual(('a', 'b'), function_utils.fn_args(Foo()))
def test_bounded_method(self):
@@ -47,7 +47,7 @@ class FnArgsTest(test.TestCase):
def bar(self, a, b):
return a + b
- self.assertEqual(('a', 'b'), util.fn_args(Foo().bar))
+ self.assertEqual(('a', 'b'), function_utils.fn_args(Foo().bar))
def test_partial_function(self):
expected_test_arg = 123
@@ -59,7 +59,7 @@ class FnArgsTest(test.TestCase):
wrapped_fn = functools.partial(fn, test_arg=123)
- self.assertEqual(('a',), util.fn_args(wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
def test_partial_function_with_positional_args(self):
expected_test_arg = 123
@@ -71,7 +71,7 @@ class FnArgsTest(test.TestCase):
wrapped_fn = functools.partial(fn, 123)
- self.assertEqual(('a',), util.fn_args(wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
self.assertEqual(3, wrapped_fn(3))
self.assertEqual(3, wrapped_fn(a=3))
@@ -88,7 +88,7 @@ class FnArgsTest(test.TestCase):
wrapped_fn = functools.partial(fn, test_arg2=456)
double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
- self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
def test_double_partial_with_positional_args_in_outer_layer(self):
expected_test_arg1 = 123
@@ -102,7 +102,7 @@ class FnArgsTest(test.TestCase):
wrapped_fn = functools.partial(fn, test_arg2=456)
double_wrapped_fn = functools.partial(wrapped_fn, 123)
- self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
self.assertEqual(3, double_wrapped_fn(3))
self.assertEqual(3, double_wrapped_fn(a=3))
@@ -119,7 +119,7 @@ class FnArgsTest(test.TestCase):
wrapped_fn = functools.partial(fn, 123) # binds to test_arg1
double_wrapped_fn = functools.partial(wrapped_fn, 456) # binds to test_arg2
- self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
self.assertEqual(3, double_wrapped_fn(3))
self.assertEqual(3, double_wrapped_fn(a=3))
diff --git a/tensorflow/python/util/serialization.py b/tensorflow/python/util/serialization.py
new file mode 100644
index 0000000000..faf5164faa
--- /dev/null
+++ b/tensorflow/python/util/serialization.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for serializing Python objects."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import tensor_shape
+
+
+def get_json_type(obj):
+ """Serializes any object to a JSON-serializable structure.
+
+ Arguments:
+ obj: the object to serialize
+
+ Returns:
+ JSON-serializable structure representing `obj`.
+
+ Raises:
+ TypeError: if `obj` cannot be serialized.
+ """
+ # if obj is a serializable Keras class instance
+ # e.g. optimizer, layer
+ if hasattr(obj, 'get_config'):
+ return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
+
+ # if obj is any numpy type
+ if type(obj).__module__ == np.__name__:
+ if isinstance(obj, np.ndarray):
+ return {'type': type(obj), 'value': obj.tolist()}
+ else:
+ return obj.item()
+
+ # misc functions (e.g. loss function)
+ if callable(obj):
+ return obj.__name__
+
+ # if obj is a python 'type'
+ if type(obj).__name__ == type.__name__:
+ return obj.__name__
+
+ if isinstance(obj, tensor_shape.Dimension):
+ return obj.value
+
+ if isinstance(obj, tensor_shape.TensorShape):
+ return obj.as_list()
+
+ raise TypeError('Not JSON Serializable:', obj)
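get_json_type is intended as the `default=` hook for json.dumps, so the objects Keras configs commonly contain (instances with get_config, numpy scalars and arrays, TensorShapes, Dimensions, callables) degrade to JSON-friendly values. A small sketch:

    import json
    import numpy as np
    from tensorflow.python.framework import tensor_shape
    from tensorflow.python.util import serialization

    payload = {"units": np.int64(3),
               "shape": tensor_shape.TensorShape([None, 4])}
    json.dumps(payload, default=serialization.get_json_type)
    # roughly '{"units": 3, "shape": [null, 4]}' (key order may vary)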
diff --git a/tensorflow/python/util/serialization_test.py b/tensorflow/python/util/serialization_test.py
new file mode 100644
index 0000000000..f16fa5377b
--- /dev/null
+++ b/tensorflow/python/util/serialization_test.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for serialization functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras._impl.keras.engine import input_layer
+from tensorflow.python.keras._impl.keras.engine import sequential
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.keras._impl.keras.layers import core
+from tensorflow.python.platform import test
+from tensorflow.python.util import serialization
+
+
+class SerializationTests(test.TestCase):
+
+ def test_serialize_dense(self):
+ dense = core.Dense(3)
+ dense(constant_op.constant([[4.]]))
+ round_trip = json.loads(json.dumps(
+ dense, default=serialization.get_json_type))
+ self.assertEqual(3, round_trip["config"]["units"])
+
+ def test_serialize_shape(self):
+ round_trip = json.loads(json.dumps(
+ tensor_shape.TensorShape([None, 2, 3]),
+ default=serialization.get_json_type))
+ self.assertIs(round_trip[0], None)
+ self.assertEqual(round_trip[1], 2)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_serialize_sequential(self):
+ model = sequential.Sequential()
+ model.add(core.Dense(4))
+ model.add(core.Dense(5))
+ model(constant_op.constant([[1.]]))
+ sequential_round_trip = json.loads(
+ json.dumps(model, default=serialization.get_json_type))
+ self.assertEqual(5, sequential_round_trip["config"][1]["config"]["units"])
+ input_round_trip = json.loads(
+ json.dumps(model._input_layers, default=serialization.get_json_type))
+ self.assertAllEqual([1, 1],
+ input_round_trip[0]["config"]["batch_input_shape"])
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_serialize_model(self):
+ x = input_layer.Input(shape=[3])
+ y = core.Dense(10)(x)
+ model = training.Model(x, y)
+ model(constant_op.constant([[1., 1., 1.]]))
+ model_round_trip = json.loads(
+ json.dumps(model, default=serialization.get_json_type))
+ self.assertEqual(
+ 10, model_round_trip["config"]["layers"][1]["config"]["units"])
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index 663036de8a..9bad4a2481 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -116,7 +116,7 @@ def getcallargs(func, *positional, **named):
it. If no attached decorators modify argspec, the final unwrapped target's
argspec will be used.
"""
- argspec = getargspec(func)
+ argspec = getfullargspec(func)
call_args = named.copy()
this = getattr(func, 'im_self', None) or getattr(func, '__self__', None)
if ismethod(func) and this:
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index be0b0bf5fb..ea87744b22 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -1086,6 +1086,13 @@ class BlasSupport {
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
+ virtual bool DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha,
const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
@@ -1948,6 +1955,13 @@ class BlasSupport {
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, float alpha, \
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda, \
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, \
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, \
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
+ bool DoBlasGemmBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, \
const port::ArraySlice<DeviceMemory<float> *> &a, int lda, \
const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, \
const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, \
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 3c1353aee3..b8ec424844 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -292,6 +292,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasGetMathMode)
STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
#endif
+#if CUDA_VERSION >= 9010
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx)
+#endif
+
} // namespace wrap
static string ToString(cublasStatus_t status) {
@@ -628,7 +632,7 @@ template <typename FuncT, typename... Args>
bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, bool err_on_failure,
bool use_tensor_op_math, Args... args) {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
CHECK(blas_ != nullptr);
if (!SetStream(stream)) {
@@ -2342,13 +2346,23 @@ bool CUDABlas::DoBlasGemmWithAlgorithm(
computation_type, algorithm, output_profile_result);
}
-template <typename T, typename FuncT>
+template <typename T>
+struct HalfAsFloat {
+ typedef T type;
+};
+
+template <>
+struct HalfAsFloat<Eigen::half> {
+ typedef float type;
+};
+
+template <typename T, typename Scalar, typename FuncT>
port::Status CUDABlas::DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,
- blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
- T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
+ Scalar beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
for (int i = 0; i < batch_count; ++i) {
@@ -2357,7 +2371,7 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
}
- typedef typename CUDAComplexT<T>::type CUDA_T;
+ typedef typename HalfAsFloat<typename CUDAComplexT<T>::type>::type CUDA_T;
const size_t size = batch_count * sizeof(CUDA_T *);
@@ -2409,18 +2423,84 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
"CUDABlas::DoBlasGemmBatched");
}
- bool ok = DoBlasInternal(
- cublas_func, stream, true /* = pointer_mode_host */,
- CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
- CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda,
- const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta),
- const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count);
+ cudaDataType_t data_type = CUDADataType<T>::type;
- if (ok) {
+#if CUDA_VERSION >= 9010
+ int cc_major, cc_minor;
+ if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+ &cc_major, &cc_minor) &&
+ cc_major >= 5) {
+ bool use_tensor_ops = TensorOpMathEnabled() && data_type == CUDA_R_16F;
+ cublasGemmAlgo_t algo =
+ (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+ cudaDataType_t compute_type =
+ (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
+ const void **a_void_ptrs = reinterpret_cast<const void **>(
+ const_cast<const CUDA_T **>(CUDAMemory(a)));
+ const void **b_void_ptrs = reinterpret_cast<const void **>(
+ const_cast<const CUDA_T **>(CUDAMemory(b)));
+ void **c_void_ptrs =
+ reinterpret_cast<void **>(const_cast<CUDA_T **>(CUDAMemory(c)));
+ bool ok;
+ ok = DoBlasInternalImpl(
+ wrap::cublasGemmBatchedEx, stream, true /* = pointer_mode_host */,
+ true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa),
+ CUDABlasTranspose(transb), m, n, k, &alpha, a_void_ptrs, data_type, lda,
+ b_void_ptrs, data_type, ldb, &beta, c_void_ptrs, data_type, ldc,
+ batch_count, compute_type, algo);
+ if (ok) {
+ return port::Status::OK();
+ }
+ return port::Status(port::error::INTERNAL,
+ "failed BLAS call, see log for details");
+ }
+#endif
+ // either CUDA_VERSION < 9.1 or SM < 5.0
+ if (data_type != CUDA_R_16F) {
+ bool ok = DoBlasInternal(
+ cublas_func, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda,
+ const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta),
+ const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count);
+ if (ok) {
+ return port::Status::OK();
+ }
+ return port::Status(port::error::INTERNAL,
+ "failed BLAS call, see log for details");
+ } else {
+ // Fall back to a loop for fp16
+ for (int b = 0; b < batch_count; ++b) {
+ const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b];
+ const DeviceMemory<T> &b_matrix = *b_ptrs_to_wrappers[b];
+ DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
+ bool ok = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a_matrix,
+ lda, b_matrix, ldb, beta, c_matrix, ldc);
+ if (!ok) {
+ return port::Status(port::error::INTERNAL,
+ "failed BLAS call, see log for details");
+ }
+ }
return port::Status::OK();
}
- return port::Status(port::error::INTERNAL,
- "failed BLAS call, see log for details");
+}
+
+bool CUDABlas::DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a_array, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b_array, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c_array,
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+ // Note: The func passed here (cublasSgemmBatched) is not actually called,
+ // due to special handling of fp16 inside DoBlasGemmBatchedInternal.
+ port::Status status = DoBlasGemmBatchedInternal(
+ wrap::cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
+ lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
+ if (!status.ok()) {
+ LOG(ERROR) << status;
+ }
+ return status.ok();
}
bool CUDABlas::DoBlasGemmBatched(
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 12dc5e47fd..42b3fde5b0 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -107,12 +107,12 @@ class CUDABlas : public blas::BlasSupport {
// A helper function to implement DoBlasGemmBatched interfaces for generic
// types.
- template <typename T, typename FuncT>
+ template <typename T, typename Scalar, typename FuncT>
port::Status DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,
- blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k, Scalar alpha,
const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
- const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
+ const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, Scalar beta,
const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
int batch_count, ScratchAllocator *scratch_allocator);
diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
index feb529297e..46e5deed84 100644
--- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
+++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
@@ -76,35 +76,36 @@ string DriverVersionStatusToString(port::StatusOr<DriverVersion> version) {
port::StatusOr<DriverVersion> StringToDriverVersion(const string &value) {
std::vector<string> pieces = port::Split(value, '.');
if (pieces.size() < 2 || pieces.size() > 4) {
- return port::Status{
+ return port::Status(
port::error::INVALID_ARGUMENT,
- port::Printf("expected %%d.%%d, %%d.%%d.%%d, or %%d.%%d.%%d.%%d form for driver version; got \"%s\"",
- value.c_str())};
+ port::Printf("expected %%d.%%d, %%d.%%d.%%d, or %%d.%%d.%%d.%%d form "
+ "for driver version; got \"%s\"",
+ value.c_str()));
}
int major;
int minor;
int patch = 0;
if (!port::safe_strto32(pieces[0], &major)) {
- return port::Status{
+ return port::Status(
port::error::INVALID_ARGUMENT,
port::Printf("could not parse major version number \"%s\" as an "
"integer from string \"%s\"",
- pieces[0].c_str(), value.c_str())};
+ pieces[0].c_str(), value.c_str()));
}
if (!port::safe_strto32(pieces[1], &minor)) {
- return port::Status{
+ return port::Status(
port::error::INVALID_ARGUMENT,
port::Printf("could not parse minor version number \"%s\" as an "
"integer from string \"%s\"",
- pieces[1].c_str(), value.c_str())};
+ pieces[1].c_str(), value.c_str()));
}
if (pieces.size() == 3 && !port::safe_strto32(pieces[2], &patch)) {
- return port::Status{
- port::error::INVALID_ARGUMENT,
- port::Printf("could not parse patch version number \"%s\" as an "
+ return port::Status(
+ port::error::INVALID_ARGUMENT,
+ port::Printf("could not parse patch version number \"%s\" as an "
"integer from string \"%s\"",
- pieces[2].c_str(), value.c_str())};
+ pieces[2].c_str(), value.c_str()));
}
DriverVersion result{major, minor, patch};
@@ -204,9 +205,9 @@ void Diagnostician::LogDiagnosticInformation() {
// Iterates through loaded DSOs with DlIteratePhdrCallback to find the
// driver-interfacing DSO version number. Returns it as a string.
port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
- port::StatusOr<DriverVersion> result{port::Status{
+ port::StatusOr<DriverVersion> result(port::Status(
port::error::NOT_FOUND,
- "was unable to find libcuda.so DSO loaded into this program"}};
+ "was unable to find libcuda.so DSO loaded into this program"));
#if defined(__APPLE__)
// OSX CUDA libraries have names like: libcuda_310.41.15_mercury.dylib
@@ -274,11 +275,11 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelModuleVersion(
static const char *kDriverFilePrelude = "Kernel Module ";
size_t offset = driver_version_file_contents.find(kDriverFilePrelude);
if (offset == string::npos) {
- return port::Status{
+ return port::Status(
port::error::NOT_FOUND,
port::StrCat("could not find kernel module information in "
"driver version file contents: \"",
- driver_version_file_contents, "\"")};
+ driver_version_file_contents, "\""));
}
string version_and_rest = driver_version_file_contents.substr(
@@ -334,25 +335,24 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelDriverVersion() {
return StringToDriverVersion(version);
}
CFRelease(kext_infos);
- auto status =
- port::Status{port::error::INTERNAL,
- port::StrCat("failed to read driver bundle version: ",
- CFStringGetCStringPtr(kDriverKextIdentifier, kCFStringEncodingUTF8))
- };
+ auto status = port::Status(
+ port::error::INTERNAL,
+ port::StrCat(
+ "failed to read driver bundle version: ",
+ CFStringGetCStringPtr(kDriverKextIdentifier, kCFStringEncodingUTF8)));
return status;
#elif defined(PLATFORM_WINDOWS)
auto status =
- port::Status{port::error::UNIMPLEMENTED,
- "kernel reported driver version not implemented on Windows"
- };
+ port::Status(port::error::UNIMPLEMENTED,
+ "kernel reported driver version not implemented on Windows");
return status;
#else
FILE *driver_version_file = fopen(kDriverVersionPath, "r");
if (driver_version_file == nullptr) {
- return port::Status{
+ return port::Status(
port::error::PERMISSION_DENIED,
port::StrCat("could not open driver version path for reading: ",
- kDriverVersionPath)};
+ kDriverVersionPath));
}
static const int kContentsSize = 1024;
@@ -371,11 +371,11 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelDriverVersion() {
return FindKernelModuleVersion(contents.begin());
}
- auto status =
- port::Status{port::error::INTERNAL,
- port::StrCat("failed to read driver version file contents: ",
- kDriverVersionPath, "; ferror: ",
- ferror(driver_version_file))};
+ auto status = port::Status(
+ port::error::INTERNAL,
+ port::StrCat(
+ "failed to read driver version file contents: ", kDriverVersionPath,
+ "; ferror: ", ferror(driver_version_file)));
fclose(driver_version_file);
return status;
#endif
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index af78efe81d..7ace7fd303 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -53,13 +53,6 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
namespace {
-// TODO(csigg): remove dnn namespace qualifier from the RNN code below.
-using ::stream_executor::dnn::BatchDescriptor;
-using ::stream_executor::dnn::ConvolutionDescriptor;
-using ::stream_executor::dnn::FilterDescriptor;
-using ::stream_executor::dnn::NormalizeDescriptor;
-using ::stream_executor::dnn::PoolingDescriptor;
-
// Converts (via narrowing) a type T value to a type U, and checks that the
// value has no value change due to the conversion.
template <typename WideT, typename NarrowT>
@@ -390,7 +383,7 @@ namespace {
// Turns a BatchDescriptor structure into a cudnn tensor handle within a scope.
class ScopedTensorDescriptor {
public:
- ScopedTensorDescriptor(const BatchDescriptor& batch_descriptor,
+ ScopedTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
cudnnDataType_t elem_type)
: handle_(nullptr) {
cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_);
@@ -464,7 +457,7 @@ class ScopedTensorDescriptor {
// Turns a FilterDescriptor structure into a cudnn filter handle within a scope.
class ScopedFilterDescriptor {
public:
- ScopedFilterDescriptor(const FilterDescriptor& filter_descriptor,
+ ScopedFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
cudnnDataType_t elem_type)
: handle_(nullptr) {
cudnnStatus_t status = cudnnCreateFilterDescriptor(&handle_);
@@ -577,7 +570,7 @@ static bool BatchnormSpatialPersistentEnabled() {
class ScopedConvolutionDescriptor {
public:
ScopedConvolutionDescriptor(
- const ConvolutionDescriptor& convolution_descriptor,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
cudnnDataType_t data_type)
: handle_(nullptr) {
cudnnStatus_t status = cudnnCreateConvolutionDescriptor(&handle_);
@@ -671,7 +664,8 @@ class ScopedConvolutionDescriptor {
// within a scope.
class ScopedPoolingDescriptor {
public:
- explicit ScopedPoolingDescriptor(const PoolingDescriptor& pooling_descriptor)
+ explicit ScopedPoolingDescriptor(
+ const dnn::PoolingDescriptor& pooling_descriptor)
: handle_(nullptr) {
cudnnStatus_t status = cudnnCreatePoolingDescriptor(&handle_);
if (status != CUDNN_STATUS_SUCCESS) {
@@ -727,7 +721,7 @@ class ScopedPoolingDescriptor {
class ScopedNormalizeDescriptor {
public:
explicit ScopedNormalizeDescriptor(
- const NormalizeDescriptor& normalize_descriptor)
+ const dnn::NormalizeDescriptor& normalize_descriptor)
: handle_(nullptr) {
cudnnStatus_t status = cudnnCreateLRNDescriptor(&handle_);
if (status != CUDNN_STATUS_SUCCESS) {
@@ -1206,16 +1200,16 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
int dims[] = {1, rnn_desc.input_size(), 1};
int strides[] = {dims[1] * dims[2], dims[2], 1};
status = cudnnSetTensorNdDescriptor(
- /*tensorDesc=*/input_desc, rnn_desc.data_type() /*dataType*/,
- sizeof(dims) / sizeof(dims[0]) /*nbDims*/, /*dimA=*/dims,
+ /*tensorDesc=*/input_desc, /*dataType=*/rnn_desc.data_type(),
+ /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
/*strideA=*/strides);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to set tensor descriptor");
size_t params_size = 0;
status = cudnnGetRNNParamsSize(
- cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*xDesc=*/input_desc, /*sizeInBytes=*/&params_size,
- rnn_desc.data_type() /*dataType*/);
+ /*dataType=*/rnn_desc.data_type());
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get RNN parameter size");
params_size_in_bytes_ = static_cast<int64>(params_size);
}
@@ -1226,8 +1220,8 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create RNN filter descriptor");
int dims[] = {static_cast<int>(params_size_in_bytes_), 1, 1};
status = cudnnSetFilterNdDescriptor(
- /*filterDesc=*/handle_, rnn_desc.data_type() /*dataType*/,
- /*format=*/CUDNN_TENSOR_NCHW, sizeof(dims) / sizeof(dims[0]) /*nbDims*/,
+ /*filterDesc=*/handle_, /*dataType=*/rnn_desc.data_type(),
+ /*format=*/CUDNN_TENSOR_NCHW, /*nbDims=*/sizeof(dims) / sizeof(dims[0]),
/*filterDimA=*/dims);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to update RNN filter descriptor");
}
@@ -1247,7 +1241,7 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
void* offset = nullptr;
if (type == 0) {
status = cudnnGetRNNLinLayerMatrixParams(
- cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_,
/*w=*/nullptr, /*linLayerID=*/region,
/*linLayerMatDesc=*/region_desc_handle,
@@ -1256,7 +1250,7 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
status, "Cudnn fails to call cudnnGetRNNLinLayerMatrixParams");
} else {
status = cudnnGetRNNLinLayerBiasParams(
- cudnn.handle() /*rnnDesc*/, rnn_desc.handle() /*rnnDesc*/,
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_,
/*w=*/nullptr, /*linLayerID=*/region,
/*linLayerBiasDesc=*/region_desc_handle,
@@ -1270,7 +1264,7 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
int n_dims;
status = cudnnGetFilterNdDescriptor(
/*filterDesc=*/region_desc_handle,
- sizeof(dims) / sizeof(dims[0]) /*nbDimsRequested*/,
+ /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
/*dataType=*/&data_type, /*format=*/&tensor_format,
/*nbDims=*/&n_dims, /*filterDimA=*/dims);
CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get filter description");
@@ -1338,7 +1332,7 @@ class CudnnRnnSequenceTensorDescriptor
int strides[] = {dims[1] * dims[2], dims[2], 1};
status = cudnnSetTensorNdDescriptor(
/*tensorDesc=*/handle, /*dataType=*/data_type,
- sizeof(dims) / sizeof(dims[0]) /*nbDims*/, /*dimA=*/dims,
+ /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
/*strideA=*/strides);
CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
// Replicate handle across the number of steps.
@@ -1390,7 +1384,7 @@ class CudnnRnnStateTensorDescriptor
int strides[] = {dims[1] * dims[2], dims[2], 1};
status = cudnnSetTensorNdDescriptor(
/*tensorDesc=*/handle_, /*dataType=*/data_type,
- sizeof(dims) / sizeof(dims[0]) /*nbDims*/, /*dimA=*/dims,
+ /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
/*strideA=*/strides);
CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
}
@@ -1497,9 +1491,9 @@ bool CheckRNNParameterSize(const CudnnHandle& cudnn,
const CudnnRnnSequenceTensorDescriptor& input_desc) {
size_t params_size_in_bytes = 0;
cudnnStatus_t status = cudnnGetRNNParamsSize(
- /*handle=*/cudnn.handle(), rnn_desc.handle() /*rnnDesc*/,
- input_desc.handles()[0] /*xDesc*/, /*sizeInBytes=*/&params_size_in_bytes,
- rnn_desc.data_type() /*dataType*/);
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
+ /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/&params_size_in_bytes,
+ /*dataType=*/rnn_desc.data_type());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "Unable to check RNN param size: " << ToString(status);
return false;
@@ -1592,8 +1586,8 @@ bool CudnnSupport::DoRnnForwardImpl(
if (is_training) {
size_t reserve_space_size_in_bytes = 0;
cudnnStatus_t status = cudnnGetRNNTrainingReserveSize(
- cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
- /*seqLength=*/model_dims.seq_length, input_desc.handles() /*xDesc*/,
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
+ /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
/*sizeInBytes=*/&reserve_space_size_in_bytes);
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "Unable to query reserve space size: " << ToString(status);
@@ -1630,30 +1624,30 @@ bool CudnnSupport::DoRnnForwardImpl(
cudnnStatus_t status;
if (!is_training) {
status = cudnnRNNForwardInference(
- cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
- model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
- input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
- input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
- input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
- params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
- output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
- output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
- output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
- workspace.size() /*workSpaceSizeInBytes*/);
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
+ /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
+ /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
+ /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
+ /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
+ /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
+ /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
+ /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
+ /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
+ /*workSpaceSizeInBytes=*/workspace.size());
} else {
status = cudnnRNNForwardTraining(
- cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
- model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
- input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
- input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
- input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
- params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
- output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
- output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
- output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
- workspace.size() /*workSpaceSizeInBytes*/,
- reserve_space.opaque() /*reserveSpace*/,
- reserve_space.size() /*reserveSpaceSizeInBytes*/);
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
+ /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
+ /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
+ /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
+ /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
+ /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
+ /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
+ /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
+ /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
+ /*workSpaceSizeInBytes=*/workspace.size(),
+ /*reserveSpace=*/reserve_space.opaque(),
+ /*reserveSpaceSizeInBytes=*/reserve_space.size());
}
if (is_profiling) {
if (!timer->Stop(AsCUDAStream(stream))) {
@@ -1748,24 +1742,24 @@ bool CudnnSupport::DoRnnBackwardImpl(
}
// make the backward data call
cudnnStatus_t status = cudnnRNNBackwardData(
- cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
- model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/,
- output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/,
- output_backprop_data.opaque() /*dy*/, output_h_desc.handle() /*dhyDesc*/,
- output_h_backprop_data.opaque() /*dhy*/,
- output_c_desc.handle() /*dcyDesc*/,
- output_c_backprop_data.opaque() /*dcy*/,
- rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
- input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
- input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
- input_desc.handles() /*dxDesc*/, input_backprop_data->opaque() /*dx*/,
- input_h_desc.handle() /*dhxDesc*/,
- input_h_backprop_data->opaque() /*dhx*/,
- input_c_desc.handle() /*dcxDesc*/,
- input_c_backprop_data->opaque() /*dcx*/, workspace.opaque() /*workspace*/,
- workspace.size() /*workSpaceSizeInBytes*/,
- reserve_space_data->opaque() /*reserveSpace*/,
- reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
+ /*seqLength=*/model_dims.seq_length, /*yDesc=*/output_desc.handles(),
+ /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
+ /*dy=*/output_backprop_data.opaque(), /*dhyDesc=*/output_h_desc.handle(),
+ /*dhy=*/output_h_backprop_data.opaque(),
+ /*dcyDesc=*/output_c_desc.handle(),
+ /*dcy=*/output_c_backprop_data.opaque(),
+ /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
+ /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
+ /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
+ /*dxDesc=*/input_desc.handles(), /*dx=*/input_backprop_data->opaque(),
+ /*dhxDesc=*/input_h_desc.handle(),
+ /*dhx=*/input_h_backprop_data->opaque(),
+ /*dcxDesc=*/input_c_desc.handle(),
+ /*dcx=*/input_c_backprop_data->opaque(), /*workspace=*/workspace.opaque(),
+ /*workSpaceSizeInBytes=*/workspace.size(),
+ /*reserveSpace=*/reserve_space_data->opaque(),
+ /*reserveSpaceSizeInBytes=*/reserve_space_data->size());
if (status != CUDNN_STATUS_SUCCESS) {
if (is_profiling) {
@@ -1780,16 +1774,16 @@ bool CudnnSupport::DoRnnBackwardImpl(
stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
// make the backward weight call
status = cudnnRNNBackwardWeights(
- cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
- model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
- input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
- input_h_data.opaque() /*hx*/, output_desc.handles() /*yDesc*/,
- output_data.opaque() /*y*/, workspace.opaque() /*workspace*/,
- workspace.size() /*workSpaceSizeInBytes*/,
- rnn_desc.params_handle() /*dwDesc*/,
- params_backprop_data->opaque() /*dw*/,
- reserve_space_data->opaque() /*reserveSpace*/,
- reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
+ /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
+ /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
+ /*hx=*/input_h_data.opaque(), /*yDesc=*/output_desc.handles(),
+ /*y=*/output_data.opaque(), /*workspace=*/workspace.opaque(),
+ /*workSpaceSizeInBytes=*/workspace.size(),
+ /*dwDesc=*/rnn_desc.params_handle(),
+ /*dw=*/params_backprop_data->opaque(),
+ /*reserveSpace=*/reserve_space_data->opaque(),
+ /*reserveSpaceSizeInBytes=*/reserve_space_data->size());
if (status != CUDNN_STATUS_SUCCESS) {
if (is_profiling) {
timer->Stop(AsCUDAStream(stream));
@@ -2415,12 +2409,12 @@ cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
template <class T>
bool CudnnSupport::DoConvolveImpl(
- Stream* stream, const BatchDescriptor& input_descriptor,
+ Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
- const FilterDescriptor& filter_descriptor,
+ const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<T>& filter_data,
- const ConvolutionDescriptor& convolution_descriptor,
- const BatchDescriptor& output_descriptor, DeviceMemory<T>* output_data,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& output_descriptor, DeviceMemory<T>* output_data,
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
@@ -3038,13 +3032,13 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl(
}
bool CudnnSupport::DoConvolve(
- Stream* stream, const BatchDescriptor& batch_descriptor,
+ Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
const DeviceMemory<float>& input_data,
- const FilterDescriptor& filter_descriptor,
+ const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<float>& filter_data,
- const ConvolutionDescriptor& convolution_descriptor,
- const BatchDescriptor& output_descriptor, DeviceMemory<float>* output_data,
- ScratchAllocator* scratch_allocator,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
+ DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return DoConvolveImpl<float>(
@@ -3054,13 +3048,13 @@ bool CudnnSupport::DoConvolve(
}
bool CudnnSupport::DoConvolve(
- Stream* stream, const BatchDescriptor& batch_descriptor,
+ Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
const DeviceMemory<double>& input_data,
- const FilterDescriptor& filter_descriptor,
+ const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<double>& filter_data,
- const ConvolutionDescriptor& convolution_descriptor,
- const BatchDescriptor& output_descriptor, DeviceMemory<double>* output_data,
- ScratchAllocator* scratch_allocator,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
+ DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return DoConvolveImpl<double>(
@@ -3070,12 +3064,12 @@ bool CudnnSupport::DoConvolve(
}
bool CudnnSupport::DoConvolve(
- Stream* stream, const BatchDescriptor& batch_descriptor,
+ Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
const DeviceMemory<Eigen::half>& input_data,
- const FilterDescriptor& filter_descriptor,
+ const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<Eigen::half>& filter_data,
- const ConvolutionDescriptor& convolution_descriptor,
- const BatchDescriptor& output_descriptor,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
@@ -3202,7 +3196,8 @@ namespace {
template <class T>
DeviceMemory<T> MaybeTransformLayout(
Stream* stream, const CudnnHandle& cudnn,
- BatchDescriptor* output_descriptor, DeviceMemory<T> backward_output_data,
+ dnn::BatchDescriptor* output_descriptor,
+ DeviceMemory<T> backward_output_data,
std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch) {
if (output_descriptor->layout() == dnn::DataLayout::kBatchDepthYX) {
return backward_output_data;
@@ -3211,7 +3206,7 @@ DeviceMemory<T> MaybeTransformLayout(
*transform_scratch =
stream->AllocateTemporaryArray<T>(backward_output_data.ElementCount())
.ConsumeValueOrDie();
- BatchDescriptor transformed_output_descriptor;
+ dnn::BatchDescriptor transformed_output_descriptor;
transformed_output_descriptor.CloneFrom(*output_descriptor);
transformed_output_descriptor.set_layout(dnn::DataLayout::kBatchDepthYX);
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
@@ -3263,12 +3258,12 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
template <class T>
bool CudnnSupport::DoConvolveBackwardDataImpl(
- Stream* stream, const FilterDescriptor& filter_descriptor,
+ Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<T>& filter_data,
- const BatchDescriptor& output_descriptor_in,
+ const dnn::BatchDescriptor& output_descriptor_in,
DeviceMemory<T> backward_output_data,
- const ConvolutionDescriptor& convolution_descriptor,
- const BatchDescriptor& input_descriptor,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& input_descriptor,
DeviceMemory<T>* backward_input_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
@@ -3287,7 +3282,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
auto cudnn = cudnn_->GetHandle(parent_, stream);
// TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
- BatchDescriptor output_descriptor;
+ dnn::BatchDescriptor output_descriptor;
output_descriptor.CloneFrom(output_descriptor_in);
std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch;
backward_output_data =
@@ -3475,12 +3470,12 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
}
bool CudnnSupport::DoConvolveBackwardData(
- Stream* stream, const FilterDescriptor& filter_descriptor,
+ Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<double>& filter_data,
- const BatchDescriptor& output_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<double> backward_output_data,
- const ConvolutionDescriptor& convolution_descriptor,
- const BatchDescriptor& input_descriptor,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& input_descriptor,
DeviceMemory<double>* backward_input_data,
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
@@ -3493,12 +3488,12 @@ bool CudnnSupport::DoConvolveBackwardData(
}
bool CudnnSupport::DoConvolveBackwardData(
- Stream* stream, const FilterDescriptor& filter_descriptor,
+ Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<float>& filter_data,
- const BatchDescriptor& output_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<float> backward_output_data,
- const ConvolutionDescriptor& convolution_descriptor,
- const BatchDescriptor& input_descriptor,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& input_descriptor,
DeviceMemory<float>* backward_input_data,
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
@@ -3511,12 +3506,12 @@ bool CudnnSupport::DoConvolveBackwardData(
}
bool CudnnSupport::DoConvolveBackwardData(
- Stream* stream, const FilterDescriptor& filter_descriptor,
+ Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<Eigen::half>& filter_data,
- const BatchDescriptor& output_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<Eigen::half> backward_output_data,
- const ConvolutionDescriptor& convolution_descriptor,
- const BatchDescriptor& input_descriptor,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& input_descriptor,
DeviceMemory<Eigen::half>* backward_input_data,
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
@@ -3554,7 +3549,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
auto cudnn = cudnn_->GetHandle(parent_, stream);
// TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
- BatchDescriptor output_descriptor;
+ dnn::BatchDescriptor output_descriptor;
output_descriptor.CloneFrom(output_descriptor_in);
std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch;
backward_output_data =
@@ -3826,27 +3821,27 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl(
}
bool CudnnSupport::DoConvolveBackwardBias(
- Stream* stream, const BatchDescriptor& input_descriptor,
+ Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<double>& input_data,
- const BatchDescriptor& bias_descriptor,
+ const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<double>* backward_bias_data) {
return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
bias_descriptor, backward_bias_data);
}
bool CudnnSupport::DoConvolveBackwardBias(
- Stream* stream, const BatchDescriptor& input_descriptor,
+ Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<float>& input_data,
- const BatchDescriptor& bias_descriptor,
+ const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<float>* backward_bias_data) {
return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
bias_descriptor, backward_bias_data);
}
bool CudnnSupport::DoConvolveBackwardBias(
- Stream* stream, const BatchDescriptor& input_descriptor,
+ Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<Eigen::half>& input_data,
- const BatchDescriptor& bias_descriptor,
+ const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<Eigen::half>* backward_bias_data) {
return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
bias_descriptor, backward_bias_data);
@@ -3994,7 +3989,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
DeviceMemory<float>* output_data) {
ScopedTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
- BatchDescriptor bias_dimensions;
+ dnn::BatchDescriptor bias_dimensions;
bias_dimensions.set_count(1)
.set_feature_map_count(dimensions.feature_map_count())
.set_height(1)
@@ -4453,8 +4448,8 @@ bool CudnnSupport::DoMemcpyH2DQuantized(
}
bool CudnnSupport::DeriveOutputBatchDescriptor(
- const BatchDescriptor& batch_descriptor,
- const FilterDescriptor& filter_descriptor,
+ const dnn::BatchDescriptor& batch_descriptor,
+ const dnn::FilterDescriptor& filter_descriptor,
const dnn::ConvolutionDescriptor& convolution_descriptor,
dnn::BatchDescriptor* output_batch_descriptor) {
ScopedTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
@@ -4493,9 +4488,8 @@ void initialize_cudnn() {
cuda::CUDAExecutor* cuda_executor =
dynamic_cast<cuda::CUDAExecutor*>(parent);
if (cuda_executor == nullptr) {
- LOG(ERROR)
- << "Attempting to initialize an instance of the cuBLAS "
- << "support library with a non-CUDA StreamExecutor";
+ LOG(ERROR) << "Attempting to initialize an instance of the cuDNN "
+ << "support library with a non-CUDA StreamExecutor";
return nullptr;
}
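
The cuda_dnn.cc hunks above are largely mechanical: descriptor parameter types such as BatchDescriptor, FilterDescriptor and ConvolutionDescriptor are now spelled with their full dnn:: qualification. A minimal sketch of why that matters when the implementing class lives in a different namespace than the descriptor types; all names below are invented for illustration and are not taken from the patch:

namespace dnn {
struct BatchDescriptor {};
struct FilterDescriptor {};
}  // namespace dnn

namespace cuda {
// The support class lives in namespace cuda, so unqualified descriptor names
// only resolve if a using-declaration is in scope; qualifying them keeps the
// out-of-line definitions in lockstep with the declarations in the header.
class CudnnSupportSketch {
 public:
  bool DoConvolve(const dnn::BatchDescriptor& input,
                  const dnn::FilterDescriptor& filter);
};

bool CudnnSupportSketch::DoConvolve(const dnn::BatchDescriptor& input,
                                    const dnn::FilterDescriptor& filter) {
  (void)input;
  (void)filter;
  return true;
}
}  // namespace cuda
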
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index 71cab145b9..e7e4192dfc 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -62,14 +62,14 @@ class CreatedContexts {
public:
// Returns whether context is a member of the live set.
static bool Has(CUcontext context) {
- tf_shared_lock lock{mu_};
+ tf_shared_lock lock(mu_);
return Live()->find(context) != Live()->end();
}
// Adds context to the live set.
static CudaContext* Add(CUcontext context) {
CHECK(context != nullptr);
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
auto cuda_context = new CudaContext(context, next_id_++);
Live()->insert(
std::make_pair(context, std::unique_ptr<CudaContext>(cuda_context)));
@@ -79,7 +79,7 @@ class CreatedContexts {
// Removes context from the live set.
static void Remove(CUcontext context) {
CHECK(context != nullptr);
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
auto it = Live()->find(context);
CHECK(it != Live()->end()) << context;
Live()->erase(it);
@@ -396,8 +396,8 @@ static port::Status InternalInit() {
LOG(ERROR) << "failed call to cuInit: " << ToString(res);
Diagnostician::LogDiagnosticInformation();
- return port::Status{port::error::ABORTED,
- port::StrCat("failed call to cuInit: ", ToString(res))};
+ return port::Status(port::error::ABORTED,
+ port::StrCat("failed call to cuInit: ", ToString(res)));
}
} // namespace
@@ -425,9 +425,9 @@ static port::Status InternalInit() {
return port::Status::OK();
}
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
- port::StrCat("failed call to cuDeviceGet: ", ToString(res))};
+ port::StrCat("failed call to cuDeviceGet: ", ToString(res)));
}
/* static */ bool CUDADriver::GetDeviceName(CUdevice device,
@@ -562,7 +562,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
}
}
- return port::Status{port::error::INTERNAL, message};
+ return port::Status(port::error::INTERNAL, message);
}
/* static */ void CUDADriver::DestroyContext(CudaContext* context) {
@@ -615,7 +615,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
/* static */ port::StatusOr<CUsharedconfig>
CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUsharedconfig shared_mem_config;
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult result = cuCtxGetSharedMemConfig(&shared_mem_config);
if (result != CUDA_SUCCESS) {
CUdevice device;
@@ -623,16 +623,16 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
LOG(ERROR) << "failed to get CUDA device shared memory config. "
<< "Context device ID: " << device
<< ", result: " << ToString(result);
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
- port::StrCat("failed to get shared memory config: ", ToString(result))};
+ port::StrCat("failed to get shared memory config: ", ToString(result)));
}
return shared_mem_config;
}
/* static */ port::Status CUDADriver::ContextSetSharedMemConfig(
CudaContext* context, CUsharedconfig shared_mem_config) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult result = cuCtxSetSharedMemConfig(shared_mem_config);
if (result != CUDA_SUCCESS) {
CUdevice device;
@@ -641,9 +641,9 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
<< "Context device ID: " << device
<< ", config: " << shared_mem_config
<< ", result: " << ToString(result);
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
- port::StrCat("failed to set shared memory config: ", ToString(result))};
+ port::StrCat("failed to set shared memory config: ", ToString(result)));
}
return port::Status::OK();
}
@@ -654,7 +654,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
unsigned int block_dim_y, unsigned int block_dim_z,
unsigned int shared_mem_bytes, CUstream stream, void **kernel_params,
void **extra) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
VLOG(2) << "launching kernel: " << function << "; gdx: " << grid_dim_x
<< " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
<< " bdx: " << block_dim_x << " bdy: " << block_dim_y
@@ -674,11 +674,11 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ port::Status CUDADriver::LoadCubin(CudaContext* context,
const char *cubin_bytes,
CUmodule *module) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult result = cuModuleLoadFatBinary(module, cubin_bytes);
if (result != CUDA_SUCCESS) {
- return port::Status{port::error::INTERNAL,
- "failed to load in-memory CUBIN: " + ToString(result)};
+ return port::Status(port::error::INTERNAL,
+ "failed to load in-memory CUBIN: " + ToString(result));
}
return port::Status::OK();
@@ -691,7 +691,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
bool ret = true;
GetDriverExecutor()->Schedule([context, ptx_contents, module, &ret,
&notification]() {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
void *ptx_data = const_cast<char *>(ptx_contents);
static const unsigned int kLogBufferBytesLimit = 1024;
unsigned int error_log_buffer_bytes = kLogBufferBytesLimit;
@@ -757,7 +757,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ bool CUDADriver::SynchronousMemsetUint8(CudaContext* context,
CUdeviceptr location,
uint8 value, size_t size) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuMemsetD8(location, value, size);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to memset memory: " << ToString(res);
@@ -770,7 +770,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUdeviceptr location,
uint32 value,
size_t uint32_count) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuMemsetD32(location, value, uint32_count);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to memset memory: " << ToString(res);
@@ -784,7 +784,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
uint8 value,
size_t uint32_count,
CUstream stream) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuMemsetD8Async(location, value, uint32_count, stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
@@ -799,7 +799,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
uint32 value,
size_t uint32_count,
CUstream stream) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuMemsetD32Async(location, value, uint32_count, stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
@@ -877,9 +877,9 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
return device;
}
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
- port::StrCat("failed to get device for context: ", ToString(result))};
+ port::StrCat("failed to get device for context: ", ToString(result)));
}
/* static */ bool CUDADriver::CreateStream(CudaContext *context,
@@ -937,7 +937,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ void CUDADriver::DeviceDeallocate(CudaContext* context,
void *location) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUdeviceptr pointer = port::bit_cast<CUdeviceptr>(location);
CUresult res = cuMemFree(pointer);
if (res != CUDA_SUCCESS) {
@@ -950,7 +950,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ void *CUDADriver::HostAllocate(CudaContext *context,
uint64 bytes) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
void *host_mem = nullptr;
// "Portable" memory is visible to all CUDA contexts. Safe for our use model.
CUresult res = cuMemHostAlloc(&host_mem, bytes, CU_MEMHOSTALLOC_PORTABLE);
@@ -963,7 +963,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ void CUDADriver::HostDeallocate(CudaContext* context,
void *location) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuMemFreeHost(location);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "error deallocating host memory at " << location << ": "
@@ -973,7 +973,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ bool CUDADriver::HostRegister(CudaContext* context, void *location,
uint64 bytes) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
// "Portable" memory is visible to all CUDA contexts. Safe for our use model.
CUresult res =
cuMemHostRegister(location, bytes, CU_MEMHOSTREGISTER_PORTABLE);
@@ -987,7 +987,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ bool CUDADriver::HostUnregister(CudaContext* context,
void *location) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuMemHostUnregister(location);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "error unregistering host memory at " << location << ": "
@@ -1000,8 +1000,8 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ port::Status CUDADriver::DestroyEvent(CudaContext* context,
CUevent *event) {
if (*event == nullptr) {
- return port::Status{port::error::INVALID_ARGUMENT,
- "input event cannot be null"};
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "input event cannot be null");
}
ScopedActivateContext activated{context};
@@ -1013,15 +1013,15 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
return port::Status::OK();
case CUDA_ERROR_DEINITIALIZED:
case CUDA_ERROR_NOT_INITIALIZED:
- return port::Status{
+ return port::Status(
port::error::FAILED_PRECONDITION,
port::Printf("error destroying CUDA event in context %p: %s", context,
- ToString(res).c_str())};
+ ToString(res).c_str()));
default:
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
port::Printf("error destroying CUDA event in context %p: %s", context,
- ToString(res).c_str())};
+ ToString(res).c_str()));
}
}
@@ -1035,15 +1035,15 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
return port::Status::OK();
case CUDA_ERROR_DEINITIALIZED:
case CUDA_ERROR_NOT_INITIALIZED:
- return port::Status{
+ return port::Status(
port::error::FAILED_PRECONDITION,
port::Printf("error recording CUDA event on stream %p: %s", stream,
- ToString(res).c_str())};
+ ToString(res).c_str()));
default:
- return port::Status{
+ return port::Status(
port::error::INVALID_ARGUMENT,
port::Printf("error recording CUDA event on stream %p: %s", stream,
- ToString(res).c_str())};
+ ToString(res).c_str()));
}
}
@@ -1052,9 +1052,9 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
ScopedActivateContext activated{context};
CUresult res = cuEventQuery(event);
if (res != CUDA_SUCCESS && res != CUDA_ERROR_NOT_READY) {
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
- port::Printf("failed to query event: %s", ToString(res).c_str())};
+ port::Printf("failed to query event: %s", ToString(res).c_str()));
}
return res;
@@ -1084,7 +1084,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
/* static */ bool CUDADriver::WaitStreamOnEvent(CudaContext* context,
CUstream stream,
CUevent event) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuStreamWaitEvent(stream, event, 0 /* = flags */);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "could not wait stream on event: " << ToString(res);
@@ -1095,7 +1095,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
}
/* static */ bool CUDADriver::SynchronizeContext(CudaContext* context) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuCtxSynchronize();
if (res != CUDA_SUCCESS) {
LOG(ERROR) << "could not synchronize on CUDA context: " << ToString(res)
@@ -1141,7 +1141,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
void *host_dst,
CUdeviceptr gpu_src,
uint64 size) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuMemcpyDtoH(host_dst, gpu_src, size);
if (res != CUDA_SUCCESS) {
return port::InternalError(
@@ -1159,7 +1159,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUdeviceptr gpu_dst,
const void *host_src,
uint64 size) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuMemcpyHtoD(gpu_dst, host_src, size);
if (res != CUDA_SUCCESS) {
return port::InternalError(port::Printf(
@@ -1176,7 +1176,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUdeviceptr gpu_dst,
CUdeviceptr gpu_src,
uint64 size) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuMemcpyDtoD(gpu_dst, gpu_src, size);
if (res != CUDA_SUCCESS) {
return port::InternalError(port::Printf(
@@ -1194,7 +1194,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUdeviceptr gpu_src,
uint64 size,
CUstream stream) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuMemcpyDtoHAsync(host_dst, gpu_src, size, stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << port::Printf(
@@ -1214,7 +1214,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
const void *host_src,
uint64 size,
CUstream stream) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult res = cuMemcpyHtoDAsync(gpu_dst, host_src, size, stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << port::Printf(
@@ -1233,7 +1233,7 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
CUdeviceptr gpu_src,
uint64 size,
CUstream stream) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
CUresult result = cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
if (result != CUDA_SUCCESS) {
LOG(ERROR) << port::Printf(
@@ -1275,12 +1275,12 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
if (res == CUDA_SUCCESS) {
return port::Status::OK();
} else if (res == CUDA_ERROR_OUT_OF_MEMORY) {
- return port::Status{port::error::RESOURCE_EXHAUSTED,
- "could not create CUDA event: out of device memory"};
+ return port::Status(port::error::RESOURCE_EXHAUSTED,
+ "could not create CUDA event: out of device memory");
} else {
- return port::Status{
+ return port::Status(
port::error::FAILED_PRECONDITION,
- port::StrCat("could not create CUDA event: ", ToString(res))};
+ port::StrCat("could not create CUDA event: ", ToString(res)));
}
}
@@ -1308,10 +1308,10 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
return context;
}
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
port::StrCat("failed to query device pointer for context: ",
- ToString(result))};
+ ToString(result)));
}
/* static */ port::StatusOr<MemorySpace> CUDADriver::GetPointerMemorySpace(
@@ -1326,16 +1326,16 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
case CU_MEMORYTYPE_HOST:
return MemorySpace::kHost;
default:
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
- port::StrCat("unknown memory space provided by CUDA API: ", value)};
+ port::StrCat("unknown memory space provided by CUDA API: ", value));
}
}
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
port::StrCat("failed to query device pointer for memory space: ",
- ToString(result))};
+ ToString(result)));
}
/* static */ port::Status CUDADriver::GetPointerAddressRange(CUdeviceptr dptr,
@@ -1348,16 +1348,16 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
// We differentiate between "this pointer is unknown" (return here) and
// "there was an internal error while performing this operation" (return
// below).
- return port::Status{
+ return port::Status(
port::error::NOT_FOUND,
port::Printf("not a device pointer %p; %s",
- reinterpret_cast<void *>(dptr), ToString(result).c_str())};
+ reinterpret_cast<void *>(dptr), ToString(result).c_str()));
}
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
port::Printf("failed to get pointer into for device pointer %p; %s",
- reinterpret_cast<void *>(dptr), ToString(result).c_str())};
+ reinterpret_cast<void *>(dptr), ToString(result).c_str()));
}
/* static */ port::StatusOr<CUdevice> CUDADriver::GetPointerDevice(
@@ -1380,10 +1380,10 @@ CUDADriver::ContextGetSharedMemConfig(CudaContext* context) {
return port::Status::OK();
}
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
port::Printf("failed to get compute capability for device: %s; %d",
- ToString(result).c_str(), device)};
+ ToString(result).c_str(), device));
}
// Helper function that turns the integer output of cuDeviceGetAttribute to type
@@ -1394,10 +1394,10 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
int value = -1;
CUresult result = cuDeviceGetAttribute(&value, attribute, device);
if (result != CUDA_SUCCESS) {
- return port::Status{
+ return port::Status(
port::error::NOT_FOUND,
port::StrCat("could not retrieve CUDA device attribute (", attribute,
- "): ", ToString(result))};
+ "): ", ToString(result)));
}
T converted = value;
return converted;
@@ -1499,10 +1499,10 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
int val;
CUresult res = cuDeviceGetAttribute(&val, attribute, device);
if (res != CUDA_SUCCESS) {
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
port::Printf("failed to get device attribute %d for device %d: %s",
- attribute, device, ToString(res).c_str())};
+ attribute, device, ToString(res).c_str()));
}
return val;
}
@@ -1523,7 +1523,7 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
/* static */ bool CUDADriver::GetDeviceMemoryInfo(CudaContext* context,
int64 *free_out,
int64 *total_out) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
size_t free = 0;
size_t total = 0;
CUresult res = cuMemGetInfo(&free, &total);
@@ -1603,10 +1603,10 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
CUresult result = cuCtxEnablePeerAccess(to->context(), 0 /* = flags */);
if (result != CUDA_SUCCESS &&
result != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED) {
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
port::Printf("failed to enable peer access from %p to %p: %s", from, to,
- ToString(result).c_str())};
+ ToString(result).c_str()));
}
return port::Status::OK();
@@ -1615,16 +1615,16 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
/* static */ port::StatusOr<int> CUDADriver::GetMaxOccupiedBlocksPerCore(
CudaContext* context, CUfunction kernel, int threads_per_block,
size_t dynamic_shared_memory_bytes) {
- ScopedActivateContext activation{context};
+ ScopedActivateContext activation(context);
int max_blocks;
CUresult result = cuOccupancyMaxActiveBlocksPerMultiprocessor(
&max_blocks, kernel, threads_per_block, dynamic_shared_memory_bytes);
if (result != CUDA_SUCCESS) {
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
port::Printf("failed to calculate occupancy of kernel %p: %s", kernel,
- ToString(result).c_str())};
+ ToString(result).c_str()));
}
return max_blocks;
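
The recurring change in cuda_driver.cc replaces brace-initialization of port::Status and ScopedActivateContext with parenthesized construction. The patch does not state the motivation; for a two-argument constructor both spellings normally pick the same overload, but in general C++ they are not interchangeable. A standalone sketch of the difference, where Status is a stand-in type rather than the real port::Status:

#include <string>
#include <vector>

struct Status {
  Status(int code, std::string msg) : code(code), msg(std::move(msg)) {}
  int code;
  std::string msg;
};

int main() {
  // For a type like this, parentheses and braces call the same constructor,
  // so the change is purely a spelling convention.
  Status a(1, "parenthesized");
  Status b{1, "braced"};

  // For other types the spelling changes meaning: braces prefer an
  // initializer_list constructor when one exists.
  std::vector<int> three_sevens(3, 7);     // {7, 7, 7}
  std::vector<int> three_and_seven{3, 7};  // {3, 7}

  return a.code + b.code +
         static_cast<int>(three_sevens.size() + three_and_seven.size());
}
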
diff --git a/tensorflow/stream_executor/cuda/cuda_fft.cc b/tensorflow/stream_executor/cuda/cuda_fft.cc
index 5b34740f9f..013ca2d7f6 100644
--- a/tensorflow/stream_executor/cuda/cuda_fft.cc
+++ b/tensorflow/stream_executor/cuda/cuda_fft.cc
@@ -138,8 +138,8 @@ port::Status CUDAFftPlan::Initialize(
CUDAFftType(type), 1 /* = batch */);
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to create cuFFT 1d plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to create cuFFT 1d plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to create cuFFT 1d plan.");
}
return port::Status::OK();
case 2:
@@ -148,8 +148,8 @@ port::Status CUDAFftPlan::Initialize(
elem_count_[1], CUDAFftType(type));
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to create cuFFT 2d plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to create cuFFT 2d plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to create cuFFT 2d plan.");
}
return port::Status::OK();
case 3:
@@ -159,29 +159,29 @@ port::Status CUDAFftPlan::Initialize(
elem_count_[2], CUDAFftType(type));
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to create cuFFT 3d plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to create cuFFT 3d plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to create cuFFT 3d plan.");
}
return port::Status::OK();
default:
LOG(ERROR) << "Invalid rank value for cufftPlan. "
"Requested 1, 2, or 3, given: "
<< rank;
- return port::Status{port::error::INVALID_ARGUMENT,
- "cufftPlan only takes rank 1, 2, or 3."};
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "cufftPlan only takes rank 1, 2, or 3.");
}
} else {
ret = wrap::cufftCreate(parent, &plan_);
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to create cuFFT plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to create cuFFT plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to create cuFFT plan.");
}
ret = wrap::cufftSetAutoAllocation(parent, plan_, 0);
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to set auto allocation for cuFFT plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to set auto allocation for cuFFT plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to set auto allocation for cuFFT plan.");
}
switch (rank) {
case 1:
@@ -190,8 +190,8 @@ port::Status CUDAFftPlan::Initialize(
&scratch_size_bytes_);
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to make cuFFT 1d plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to make cuFFT 1d plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to make cuFFT 1d plan.");
}
break;
case 2:
@@ -200,8 +200,8 @@ port::Status CUDAFftPlan::Initialize(
&scratch_size_bytes_);
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to make cuFFT 2d plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to make cuFFT 2d plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to make cuFFT 2d plan.");
}
break;
case 3:
@@ -210,16 +210,16 @@ port::Status CUDAFftPlan::Initialize(
CUDAFftType(type), &scratch_size_bytes_);
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to make cuFFT 3d plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to make cuFFT 3d plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to make cuFFT 3d plan.");
}
break;
default:
LOG(ERROR) << "Invalid rank value for cufftPlan. "
"Requested 1, 2, or 3, given: "
<< rank;
- return port::Status{port::error::INVALID_ARGUMENT,
- "cufftPlan only takes rank 1, 2, or 3."};
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "cufftPlan only takes rank 1, 2, or 3.");
}
return UpdateScratchAllocator(stream, scratch_allocator);
}
@@ -233,23 +233,23 @@ port::Status CUDAFftPlan::Initialize(
output_distance, CUDAFftType(type), batch_count);
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to create cuFFT batched plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to create cuFFT batched plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to create cuFFT batched plan.");
}
} else {
auto ret = wrap::cufftCreate(parent, &plan_);
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to create cuFFT batched plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to create cuFFT batched plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to create cuFFT batched plan.");
}
ret = wrap::cufftSetAutoAllocation(parent, plan_, 0);
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to set auto allocation for cuFFT batched plan:"
<< ret;
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
- "Failed to set auto allocation for cuFFT batched plan."};
+ "Failed to set auto allocation for cuFFT batched plan.");
}
ret = wrap::cufftMakePlanMany(
parent, plan_, rank, elem_count_,
@@ -259,8 +259,8 @@ port::Status CUDAFftPlan::Initialize(
&scratch_size_bytes_);
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to make cuFFT batched plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to make cuFFT batched plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to make cuFFT batched plan.");
}
return UpdateScratchAllocator(stream, scratch_allocator);
}
@@ -293,8 +293,8 @@ port::Status CUDAFftPlan::UpdateScratchAllocator(
cufftResult_t ret = wrap::cufftSetWorkArea(parent_, plan_, scratch_.opaque());
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to set work area for cuFFT plan:" << ret;
- return port::Status{port::error::INTERNAL,
- "Failed to set work area for cuFFT plan."};
+ return port::Status(port::error::INTERNAL,
+ "Failed to set work area for cuFFT plan.");
}
return port::Status::OK();
}
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index 7c87d33d21..f2be68bc42 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -609,10 +609,10 @@ port::Status CUDAExecutor::WaitForEvent(Stream *stream, Event *event) {
AsCUDAEvent(event)->cuda_event())) {
return port::Status::OK();
} else {
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
port::Printf("error recording waiting for CUDA event on stream %p",
- stream)};
+ stream));
}
}
diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc
index 649224a20e..ebe4dcc904 100644
--- a/tensorflow/stream_executor/cuda/cuda_platform.cc
+++ b/tensorflow/stream_executor/cuda/cuda_platform.cc
@@ -124,9 +124,9 @@ port::StatusOr<StreamExecutor*> CudaPlatform::FirstExecutorForBus(
}
}
- return port::Status{
+ return port::Status(
port::error::NOT_FOUND,
- port::Printf("Executor for bus %d not found.", bus_ordinal)};
+ port::Printf("Executor for bus %d not found.", bus_ordinal));
}
Platform::Id CudaPlatform::id() const { return kCudaPlatformId; }
@@ -172,11 +172,11 @@ CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
this, MakeUnique<CUDAExecutor>(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
port::Printf(
"failed initializing StreamExecutor for CUDA device ordinal %d: %s",
- config.ordinal, init_status.ToString().c_str())};
+ config.ordinal, init_status.ToString().c_str()));
}
return std::move(executor);
diff --git a/tensorflow/stream_executor/cuda/cuda_rng.cc b/tensorflow/stream_executor/cuda/cuda_rng.cc
index e289e7ced5..88c4f15792 100644
--- a/tensorflow/stream_executor/cuda/cuda_rng.cc
+++ b/tensorflow/stream_executor/cuda/cuda_rng.cc
@@ -114,7 +114,7 @@ CUDARng::~CUDARng() {
}
bool CUDARng::Init() {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
CHECK(rng_ == nullptr);
curandStatus_t ret =
@@ -150,7 +150,7 @@ constexpr bool ComplexIsConsecutiveFloats() {
template <typename T>
bool CUDARng::DoPopulateRandUniformInternal(Stream *stream,
DeviceMemory<T> *v) {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
static_assert(ComplexIsConsecutiveFloats(),
"std::complex values are not stored as consecutive values");
@@ -209,7 +209,7 @@ bool CUDARng::DoPopulateRandGaussianInternal(Stream *stream, ElemT mean,
ElemT stddev,
DeviceMemory<ElemT> *v,
FuncT func) {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
if (!SetStream(stream)) {
return false;
@@ -241,7 +241,7 @@ bool CUDARng::DoPopulateRandGaussian(Stream *stream, double mean, double stddev,
}
bool CUDARng::SetSeed(Stream *stream, const uint8 *seed, uint64 seed_bytes) {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
CHECK(rng_ != nullptr);
if (!CheckSeed(seed, seed_bytes)) {
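
cuda_rng.cc (like kernel_spec.cc, stream.cc and stream_executor_pimpl.cc later in the patch) applies the same convention to lock guards: mutex_lock lock{mu_} becomes mutex_lock lock(mu_). Both forms construct a named RAII guard that holds the mutex for the enclosing scope; a small sketch using std::mutex, since the TensorFlow mutex_lock alias is not reproduced here:

#include <mutex>

std::mutex mu_;
int counter = 0;

void IncrementBraced() {
  std::lock_guard<std::mutex> lock{mu_};  // braced guard
  ++counter;
}

void IncrementParen() {
  std::lock_guard<std::mutex> lock(mu_);  // parenthesized guard, same effect
  ++counter;
}
// Either way the guard must be a named local so it lives until the end of the
// scope; an unnamed temporary would release the mutex immediately.
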
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 18606eb717..38abc66079 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -882,8 +882,8 @@ enum class ElementwiseOperation { kAdd, kMultiply };
string ElementwiseOperationString(ElementwiseOperation op);
-// A simple class representing the version of the backing library, to
-// workaround the "too perfect forwarding" issue in gcc6+ compilers.
+// A simple class representing the version of the backing library, to
+// workaround the "too perfect forwarding" issue in gcc6+ compilers.
// See PR#16309 and issue #18402 for links discussing the issue.
class VersionInfo {
public:
@@ -1051,10 +1051,8 @@ class DnnSupport {
// convolution result.
// scratch_allocator: un-owned, may-be-null object that may allocate scratch
// space in order to speed up the convolution operation.
- // algorithm: specifies which algorithm should be used for the
- // operation. If algorithm.is_default(), the system will pick an algorithm
- // by default. The coding of the algorithm is be interpretted by the
- // underlying implementation.
+ // algorithm_config: specifies which algorithm should be used for the
+ // operation.
// output_profile_result: the output profile result for this call. The
// profiling is only enabled when this is not nullptr.
//
@@ -1153,17 +1151,13 @@ class DnnSupport {
// convolution input.
// filter_descriptor: dimensions of the convolution filter.
// convolution_descriptor: stride of the convolution filter.
- // input. This can be DeviceMemory pointing to NULL only when activation_mode
- // is kNone.
// output_descriptor: dimensions of the output layer.
// output_data: un-owned device memory region in which to place the
// convolution result.
// scratch_allocator: un-owned, may-be-null object that may allocate scratch
// space in order to speed up the convolution operation.
- // algorithm: an integer to specify which algorithm should be used for the
- // operation. kDefaultAlgorithm means the system will pick an algorithm
- // by default. The coding of the algorithm is be interpreted by the
- // underlying implementation.
+ // algorithm_config: specifies which algorithm should be used for the
+ // operation.
// output_profile_result: the output profile result for this call. The
// profiling is only enabled when this is not nullptr.
//
@@ -1220,6 +1214,7 @@ class DnnSupport {
ProfileResult* output_profile_result) = 0;
// Return a list of algorithms supported by the forward convolution pass.
+ // cc_major and cc_minor are the compute capabilities of the device.
virtual bool GetConvolveAlgorithms(
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<AlgorithmDesc>* out_algorithms);
@@ -2036,8 +2031,8 @@ class DnnSupport {
const dnn::AlgorithmConfig& algorithm_config,
float dropout, uint64 seed,
ScratchAllocator* state_allocator) {
- return port::Status{port::error::UNIMPLEMENTED,
- "createRnnDescriptor is unimplemented"};
+ return port::Status(port::error::UNIMPLEMENTED,
+ "createRnnDescriptor is unimplemented");
}
// Create a RNN sequence descriptor that specifies either the input or output
@@ -2051,8 +2046,8 @@ class DnnSupport {
virtual port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
int data_size, dnn::DataType data_type) {
- return port::Status{port::error::UNIMPLEMENTED,
- "createRnnSequenceTensorDescriptor is unimplemented"};
+ return port::Status(port::error::UNIMPLEMENTED,
+ "createRnnSequenceTensorDescriptor is unimplemented");
}
// Create an RNN state descriptor that specifies the input or hidden state.
@@ -2060,8 +2055,8 @@ class DnnSupport {
virtual port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
dnn::DataType data_type) {
- return port::Status{port::error::UNIMPLEMENTED,
- "createRnnStateTensorDescriptor is unimplemented"};
+ return port::Status(port::error::UNIMPLEMENTED,
+ "createRnnStateTensorDescriptor is unimplemented");
}
// Enqueue a forward operation of the RNN model onto the stream.
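
The dnn.h hunks also touch the default implementations of the RNN factory methods, which simply return an UNIMPLEMENTED status so that a backend only overrides the capabilities it actually supports. A compact sketch of that pattern, with invented names standing in for port::Status and DnnSupport:

#include <string>
#include <utility>

struct Status {
  bool ok;
  std::string msg;
  static Status Unimplemented(std::string what) {
    return {false, std::move(what)};
  }
};

class DnnSupport {
 public:
  virtual ~DnnSupport() = default;
  // Backends without RNN support inherit this default and report it cleanly
  // instead of failing to link or aborting at runtime.
  virtual Status createRnnDescriptor() {
    return Status::Unimplemented("createRnnDescriptor is unimplemented");
  }
};

class CudnnLikeBackend : public DnnSupport {
 public:
  Status createRnnDescriptor() override { return {true, ""}; }
};
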
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h
index 0c3991c151..e82f57569f 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.h
+++ b/tensorflow/stream_executor/host/host_gpu_executor.h
@@ -106,19 +106,19 @@ class HostExecutor : public internal::StreamExecutorInterface {
bool HostCallback(Stream *stream, std::function<void()> callback) override;
port::Status AllocateEvent(Event *event) override {
- return port::Status{port::error::UNIMPLEMENTED, ""};
+ return port::Status(port::error::UNIMPLEMENTED, "");
}
port::Status DeallocateEvent(Event *event) override {
- return port::Status{port::error::UNIMPLEMENTED, ""};
+ return port::Status(port::error::UNIMPLEMENTED, "");
}
port::Status RecordEvent(Stream *stream, Event *event) override {
- return port::Status{port::error::UNIMPLEMENTED, ""};
+ return port::Status(port::error::UNIMPLEMENTED, "");
}
port::Status WaitForEvent(Stream *stream, Event *event) override {
- return port::Status{port::error::UNIMPLEMENTED, ""};
+ return port::Status(port::error::UNIMPLEMENTED, "");
}
Event::Status PollForEventStatus(Event *event) override {
@@ -167,7 +167,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
"Shared memory configuration is unsupported for host "
"executors."};
LOG(INFO) << error_msg;
- return port::Status{port::error::UNIMPLEMENTED, error_msg};
+ return port::Status(port::error::UNIMPLEMENTED, error_msg);
}
bool SupportsBlas() const override;
diff --git a/tensorflow/stream_executor/host/host_platform.cc b/tensorflow/stream_executor/host/host_platform.cc
index a652b08b4f..eeb6a06e3d 100644
--- a/tensorflow/stream_executor/host/host_platform.cc
+++ b/tensorflow/stream_executor/host/host_platform.cc
@@ -70,11 +70,11 @@ HostPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
this, MakeUnique<HostExecutor>(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
- return port::Status{
+ return port::Status(
port::error::INTERNAL,
port::Printf(
"failed initializing StreamExecutor for device ordinal %d: %s",
- config.ordinal, init_status.ToString().c_str())};
+ config.ordinal, init_status.ToString().c_str()));
}
return std::move(executor);
diff --git a/tensorflow/stream_executor/host_or_device_scalar.h b/tensorflow/stream_executor/host_or_device_scalar.h
index c9e3e14778..1f5d4b9260 100644
--- a/tensorflow/stream_executor/host_or_device_scalar.h
+++ b/tensorflow/stream_executor/host_or_device_scalar.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
#define TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
-#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/platform/logging.h"
namespace stream_executor {
diff --git a/tensorflow/stream_executor/kernel_spec.cc b/tensorflow/stream_executor/kernel_spec.cc
index f0a5785b72..902892af3f 100644
--- a/tensorflow/stream_executor/kernel_spec.cc
+++ b/tensorflow/stream_executor/kernel_spec.cc
@@ -93,7 +93,7 @@ const char *CudaPtxInMemory::default_text() const {
return nullptr;
}
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
auto ptx = ptx_by_compute_capability_.begin()->second;
// Check if there is an entry in decompressed ptx table.
@@ -127,7 +127,7 @@ const char *CudaPtxInMemory::text(int compute_capability_major,
return nullptr;
}
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
// Check if there is an entry in decompressed ptx table.
auto decompressed_ptx_iter = decompressed_ptx_.find(ptx_iter->second);
diff --git a/tensorflow/stream_executor/plugin_registry.cc b/tensorflow/stream_executor/plugin_registry.cc
index 7812703efd..c53685c57b 100644
--- a/tensorflow/stream_executor/plugin_registry.cc
+++ b/tensorflow/stream_executor/plugin_registry.cc
@@ -72,11 +72,11 @@ port::Status PluginRegistry::RegisterFactoryInternal(
mutex_lock lock{GetPluginRegistryMutex()};
if (factories->find(plugin_id) != factories->end()) {
- return port::Status{
+ return port::Status(
port::error::ALREADY_EXISTS,
port::Printf("Attempting to register factory for plugin %s when "
"one has already been registered",
- plugin_name.c_str())};
+ plugin_name.c_str()));
}
(*factories)[plugin_id] = factory;
@@ -92,9 +92,9 @@ port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal(
if (iter == factories.end()) {
iter = generic_factories.find(plugin_id);
if (iter == generic_factories.end()) {
- return port::Status{
+ return port::Status(
port::error::NOT_FOUND,
- port::Printf("Plugin ID %p not registered.", plugin_id)};
+ port::Printf("Plugin ID %p not registered.", plugin_id));
}
}
@@ -212,10 +212,11 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id,
plugin_id = default_factories_[platform_id].FACTORY_VAR; \
\
if (plugin_id == kNullPlugin) { \
- return port::Status{port::error::FAILED_PRECONDITION, \
- "No suitable " PLUGIN_STRING \
- " plugin registered. Have you linked in a " \
- PLUGIN_STRING "-providing plugin?"}; \
+ return port::Status( \
+ port::error::FAILED_PRECONDITION, \
+ "No suitable " PLUGIN_STRING \
+ " plugin registered. Have you linked in a " PLUGIN_STRING \
+ "-providing plugin?"); \
} else { \
VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, " \
<< plugin_names_[plugin_id]; \
@@ -231,9 +232,9 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id,
PlatformKind platform_kind, PluginId plugin_id) { \
auto iter = platform_id_by_kind_.find(platform_kind); \
if (iter == platform_id_by_kind_.end()) { \
- return port::Status{port::error::FAILED_PRECONDITION, \
+ return port::Status(port::error::FAILED_PRECONDITION, \
port::Printf("Platform kind %d not registered.", \
- static_cast<int>(platform_kind))}; \
+ static_cast<int>(platform_kind))); \
} \
return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \
}
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 093f0c9306..4a98cfe164 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -276,7 +276,7 @@ Stream::~Stream() {
Stream &Stream::Init() {
VLOG_CALL();
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
CHECK_EQ(false, allocated_)
<< "stream appears to already have been initialized";
CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
@@ -1899,7 +1899,7 @@ Stream &Stream::ThenCopyDevice2HostBuffer(
}
Stream *Stream::GetOrCreateSubStream() {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
for (auto &stream : sub_streams_) {
if (stream.second) {
stream.second = false;
@@ -1916,7 +1916,7 @@ Stream *Stream::GetOrCreateSubStream() {
}
void Stream::ReturnSubStream(Stream *sub_stream) {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
for (auto &stream : sub_streams_) {
if (stream.first.get() == sub_stream) {
stream.second = true;
@@ -4482,6 +4482,40 @@ Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
Stream &Stream::ThenBlasGemmBatched(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
+ int batch_count) {
+ return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
+ b, ldb, beta, c, ldc, batch_count,
+ /*scratch_allocator=*/nullptr);
+}
+
+Stream &Stream::ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
+ float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
+ int, int, ScratchAllocator *>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
+ k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
+ scratch_allocator);
+}
+
+Stream &Stream::ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
@@ -5196,7 +5230,7 @@ port::Status Stream::BlockHostUntilDone() {
port::Status first_error;
{
// Wait until all active sub-streams have done their tasks.
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
for (auto &stream : sub_streams_) {
if (!stream.second) {
first_error.Update(stream.first->BlockHostUntilDone());
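
stream.cc gains an Eigen::half overload of ThenBlasGemmBatched that forwards to ThenBlasGemmBatchedWithScratch with a null scratch allocator, so the scratch-free and scratch-aware entry points share one code path. A stripped-down sketch of that delegation; the Stream and ScratchAllocator below are placeholders, not the real stream_executor types:

class ScratchAllocator;

class Stream {
 public:
  // Convenience overload: callers that do not care about scratch space get
  // the same behavior as passing a null allocator explicitly.
  Stream& ThenGemmBatched(int batch_count) {
    return ThenGemmBatchedWithScratch(batch_count,
                                      /*scratch_allocator=*/nullptr);
  }

  Stream& ThenGemmBatchedWithScratch(int batch_count,
                                     ScratchAllocator* scratch_allocator) {
    // The real implementation dispatches to the BLAS plugin here.
    (void)batch_count;
    (void)scratch_allocator;
    return *this;
  }
};
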
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 3d1b011c57..3da1b856d6 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -66,9 +66,6 @@ namespace dnn {
class BatchDescriptor;
class FilterDescriptor;
class ConvolutionDescriptor;
-class BatchDescriptor;
-class FilterDescriptor;
-class ConvolutionDescriptor;
class ProfileResult;
class AlgorithmDesc;
} // namespace dnn
@@ -1474,6 +1471,13 @@ class Stream {
blas::ProfileResult *output_profile_result);
// See BlasSupport::DoBlasGemmBatched.
+ Stream &ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+ int ldc, int batch_count);
Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, float alpha,
const port::ArraySlice<DeviceMemory<float> *> &a,
@@ -1508,6 +1512,13 @@ class Stream {
int batch_count);
Stream &ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator);
+ Stream &ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
@@ -2005,7 +2016,7 @@ class Stream {
friend class ocl::CLBlas; // for parent_.
bool InErrorState() const LOCKS_EXCLUDED(mu_) {
- tf_shared_lock lock{mu_};
+ tf_shared_lock lock(mu_);
return !ok_;
}
@@ -2015,7 +2026,7 @@ class Stream {
if (operation_retcode) {
return;
}
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
ok_ = false;
}
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 20579790ef..eecd5bfe1f 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -232,7 +232,7 @@ void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
}
void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
- tf_shared_lock lock{mu_};
+ tf_shared_lock lock(mu_);
*records_out = mem_allocs_;
}
@@ -256,13 +256,13 @@ port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
string error_msg = port::Printf(
"Invalid shared memory config specified: %d", static_cast<int>(config));
LOG(ERROR) << error_msg;
- return port::Status{port::error::INVALID_ARGUMENT, error_msg};
+ return port::Status(port::error::INVALID_ARGUMENT, error_msg);
}
return implementation_->SetDeviceSharedMemoryConfig(config);
}
const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
if (device_description_ != nullptr) {
return *device_description_;
}
@@ -393,7 +393,7 @@ StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
}
dnn::DnnSupport *StreamExecutor::AsDnn() {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
if (dnn_ != nullptr) {
return dnn_.get();
}
@@ -403,7 +403,7 @@ dnn::DnnSupport *StreamExecutor::AsDnn() {
}
blas::BlasSupport *StreamExecutor::AsBlas() {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
if (blas_ != nullptr) {
return blas_.get();
}
@@ -413,7 +413,7 @@ blas::BlasSupport *StreamExecutor::AsBlas() {
}
fft::FftSupport *StreamExecutor::AsFft() {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
if (fft_ != nullptr) {
return fft_.get();
}
@@ -423,7 +423,7 @@ fft::FftSupport *StreamExecutor::AsFft() {
}
rng::RngSupport *StreamExecutor::AsRng() {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
if (rng_ != nullptr) {
return rng_.get();
}
@@ -582,12 +582,12 @@ port::Status StreamExecutor::SynchronousMemcpyD2H(
result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
if (!result.ok()) {
- result = port::Status{port::error::INTERNAL,
+ result = port::Status(port::error::INTERNAL,
port::Printf("failed to synchronously memcpy "
"device-to-host: device %p to host %p "
"size %lld: %s",
device_src.opaque(), host_dst, size,
- result.ToString().c_str())};
+ result.ToString().c_str()));
}
return result;
@@ -605,12 +605,12 @@ port::Status StreamExecutor::SynchronousMemcpyH2D(
result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
if (!result.ok()) {
- result = port::Status{
+ result = port::Status(
port::error::INTERNAL,
port::Printf("failed to synchronously memcpy host-to-device: host "
"%p to device %p size %lld: %s",
host_src, device_dst->opaque(), size,
- result.ToString().c_str())};
+ result.ToString().c_str()));
}
return result;
@@ -723,7 +723,7 @@ void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
mem_allocs_[opaque] = AllocRecord{
bytes, ""};
}
@@ -731,7 +731,7 @@ void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
void StreamExecutor::EraseAllocRecord(void *opaque) {
if (FLAGS_check_device_leaks && opaque != nullptr) {
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
LOG(ERROR) << "Deallocating unknown pointer: "
<< port::Printf("0x%p", opaque);
@@ -745,7 +745,7 @@ void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }
void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
{
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
if (listeners_.find(listener) != listeners_.end()) {
LOG(INFO) << "Attempt to register already-registered listener, "
<< listener;
@@ -759,7 +759,7 @@ void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
{
- mutex_lock lock{mu_};
+ mutex_lock lock(mu_);
if (listeners_.find(listener) == listeners_.end()) {
LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
return false;
@@ -776,7 +776,7 @@ void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
if (tracing_enabled_) {
{
// instance tracers held in a block to limit the lock lifetime.
- tf_shared_lock lock{mu_};
+ tf_shared_lock lock(mu_);
for (TraceListener *listener : listeners_) {
(listener->*trace_call)(std::forward<ArgsT>(args)...);
}
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index ab6b00f660..e426cf9931 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -177,6 +177,9 @@ class StreamExecutor {
//
// Resets the internal contents of mem to be null-representative, but this
// null-out effect should not be relied upon in client code.
+ //
+ // TODO(jlebar): Change this to accept a DeviceMemoryBase by value, see
+ // discussion in cl/195744342.
void Deallocate(DeviceMemoryBase *mem);
// Retrieves a mapping of active opaque device memory pointer to a string
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 14944ec25b..d71fd71bbd 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -959,15 +959,6 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
if not cuda_deps:
cuda_deps = []
- if 'linkstatic' not in kwargs or kwargs['linkstatic'] != 1:
- enable_text_relocation_linkopt = select({
- clean_dep("//tensorflow:darwin"): [],
- clean_dep("//tensorflow:windows"): [],
- "//conditions:default": ['-Wl,-z,notext'],})
- if 'linkopts' in kwargs:
- kwargs['linkopts'] += enable_text_relocation_linkopt
- else:
- kwargs['linkopts'] = enable_text_relocation_linkopt
native.cc_library(
deps=deps + if_cuda(cuda_deps + [
clean_dep("//tensorflow/core:cuda"),
@@ -1309,7 +1300,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]):
native.cc_library(
name=basename + "_gpu",
srcs=gpu_srcs,
- copts=_cuda_copts(),
+ copts=_cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]),
deps=deps + if_cuda(cuda_deps))
cuda_deps.extend([":" + basename + "_gpu"])
@@ -1445,13 +1436,13 @@ def tf_py_wrap_cc(name,
extra_linkopts = select({
"@local_config_cuda//cuda:darwin": [
"-Wl,-exported_symbols_list",
- "%s.lds"%vscriptname,
+ "$(location %s.lds)"%vscriptname,
],
clean_dep("//tensorflow:windows"): [],
clean_dep("//tensorflow:windows_msvc"): [],
"//conditions:default": [
"-Wl,--version-script",
- "%s.lds"%vscriptname,
+ "$(location %s.lds)"%vscriptname,
]
})
extra_deps += select({
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD
index a1c569951e..edc1e078bb 100644
--- a/tensorflow/tools/api/generator/BUILD
+++ b/tensorflow/tools/api/generator/BUILD
@@ -101,6 +101,7 @@ genrule(
"api/profiler/__init__.py",
"api/python_io/__init__.py",
"api/resource_loader/__init__.py",
+ "api/strings/__init__.py",
"api/saved_model/__init__.py",
"api/saved_model/builder/__init__.py",
"api/saved_model/constants/__init__.py",
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py
index 788f6d3573..d72cb3b7dd 100644
--- a/tensorflow/tools/api/generator/create_python_api.py
+++ b/tensorflow/tools/api/generator/create_python_api.py
@@ -23,13 +23,18 @@ import collections
import os
import sys
-from tensorflow import python # pylint: disable=unused-import
+# Populate `sys.modules` which will be traversed to find TensorFlow modules.
+# Make sure your module gets imported in tensorflow/python/__init__.py for it
+# to be seen by this script.
+import tensorflow.python # pylint: disable=unused-import
+
from tensorflow.python.util import tf_decorator
_API_CONSTANTS_ATTR = '_tf_api_constants'
_API_NAMES_ATTR = '_tf_api_names'
_API_DIR = '/api/'
+_DEFAULT_MODULE_FILTER = 'tensorflow.'
_OUTPUT_MODULE = 'tensorflow.tools.api.generator.api'
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
@@ -146,9 +151,12 @@ __all__.extend([s for s in _names_with_underscore])
return module_text_map
-def get_api_init_text():
+def get_api_init_text(module_filter):
"""Get a map from destination module to __init__.py code for that module.
+ Args:
+ module_filter: Substring used to filter module names to process.
+
Returns:
A dictionary where
key: (string) destination module (for e.g. tf or tf.consts).
@@ -162,7 +170,7 @@ def get_api_init_text():
for module in list(sys.modules.values()):
# Only look at tensorflow modules.
if (not module or not hasattr(module, '__name__') or
- 'tensorflow.' not in module.__name__):
+ module_filter not in module.__name__):
continue
# Do not generate __init__.py files for contrib modules for now.
if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
@@ -215,12 +223,13 @@ def get_api_init_text():
return module_code_builder.build()
-def create_api_files(output_files):
+def create_api_files(output_files, module_filter):
"""Creates __init__.py files for the Python API.
Args:
output_files: List of __init__.py file paths to create.
Each file must be under api/ directory.
+ module_filter: Substring used to filter module names to process.
Raises:
ValueError: if an output file is not under api/ directory,
@@ -248,7 +257,7 @@ def create_api_files(output_files):
os.makedirs(os.path.dirname(file_path))
open(file_path, 'a').close()
- module_text_map = get_api_init_text()
+ module_text_map = get_api_init_text(module_filter)
# Add imports to output files.
missing_output_files = []
@@ -270,10 +279,7 @@ def create_api_files(output_files):
',\n'.join(sorted(missing_output_files)))
-def main(output_files):
- create_api_files(output_files)
-
-if __name__ == '__main__':
+def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'outputs', metavar='O', type=str, nargs='+',
@@ -281,7 +287,12 @@ if __name__ == '__main__':
'semicolon-separated list of Python files that we expect this script to '
'output. If multiple files are passed in, then we assume output files '
'are listed directly as arguments.')
+ parser.add_argument(
+ '--module_filter', default=_DEFAULT_MODULE_FILTER, type=str,
+ help='Only process modules with names containing this substring.'
+ )
args = parser.parse_args()
+
if len(args.outputs) == 1:
# If we only get a single argument, then it must be a file containing
# list of outputs.
@@ -289,4 +300,8 @@ if __name__ == '__main__':
outputs = [line.strip() for line in output_list_file.read().split(';')]
else:
outputs = args.outputs
- main(outputs)
+ create_api_files(outputs, args.module_filter)
+
+
+if __name__ == '__main__':
+ main()
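
The new --module_filter flag (defaulting to _DEFAULT_MODULE_FILTER) narrows which entries in sys.modules the generator walks when emitting the api/ __init__.py files. A minimal sketch of driving the generator programmatically, assuming it is run from a TensorFlow source checkout where this package is importable; the filter value shown is simply the default:

    # Hypothetical direct use of the generator; normally it runs from a genrule.
    from tensorflow.tools.api.generator import create_python_api

    # Only modules whose names contain this substring are traversed.
    text_map = create_python_api.get_api_init_text(module_filter='tensorflow.')
    for module_name, init_text in sorted(text_map.items()):
        print(module_name, '->', len(init_text.splitlines()), 'lines of imports')
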
diff --git a/tensorflow/tools/api/generator/create_python_api_test.py b/tensorflow/tools/api/generator/create_python_api_test.py
index 218c812045..5f1052249e 100644
--- a/tensorflow/tools/api/generator/create_python_api_test.py
+++ b/tensorflow/tools/api/generator/create_python_api_test.py
@@ -56,7 +56,8 @@ class CreatePythonApiTest(test.TestCase):
del sys.modules[_MODULE_NAME]
def testFunctionImportIsAdded(self):
- imports = create_python_api.get_api_init_text()
+ imports = create_python_api.get_api_init_text(
+ module_filter=create_python_api._DEFAULT_MODULE_FILTER)
expected_import = (
'from test.tensorflow.test_module import test_op as test_op1')
self.assertTrue(
@@ -69,14 +70,16 @@ class CreatePythonApiTest(test.TestCase):
msg='%s not in %s' % (expected_import, str(imports)))
def testClassImportIsAdded(self):
- imports = create_python_api.get_api_init_text()
+ imports = create_python_api.get_api_init_text(
+ module_filter=create_python_api._DEFAULT_MODULE_FILTER)
expected_import = 'from test.tensorflow.test_module import TestClass'
self.assertTrue(
'TestClass' in str(imports),
msg='%s not in %s' % (expected_import, str(imports)))
def testConstantIsAdded(self):
- imports = create_python_api.get_api_init_text()
+ imports = create_python_api.get_api_init_text(
+ module_filter=create_python_api._DEFAULT_MODULE_FILTER)
expected = 'from test.tensorflow.test_module import _TEST_CONSTANT'
self.assertTrue(expected in str(imports),
msg='%s not in %s' % (expected, str(imports)))
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
index cbbd077c97..8e7e945ed1 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
@@ -44,7 +44,7 @@ tf_class {
}
member_method {
name: "from_generator"
- argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 9a56ae8675..5cfb2fd2f0 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -45,7 +45,7 @@ tf_class {
}
member_method {
name: "from_generator"
- argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
index e5ec824bb8..3327e5b274 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -45,7 +45,7 @@ tf_class {
}
member_method {
name: "from_generator"
- argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
index 008239789c..9d59375282 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
@@ -45,7 +45,7 @@ tf_class {
}
member_method {
name: "from_generator"
- argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
index be9ba4ce85..cf22e39d4c 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
@@ -24,6 +24,10 @@ tf_class {
argspec: "args=[\'self\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'weighted_sum\'], "
}
member_method {
+ name: "eval_dir"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt
index 91fca67b6b..a363bceae3 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt
@@ -24,6 +24,10 @@ tf_class {
argspec: "args=[\'self\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'weighted_sum\'], "
}
member_method {
+ name: "eval_dir"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 53a903c239..099838fa65 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -24,6 +24,10 @@ tf_class {
argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
}
member_method {
+ name: "eval_dir"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index ba17c90de2..87bd19a23a 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -24,6 +24,10 @@ tf_class {
argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
}
member_method {
+ name: "eval_dir"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
index cd4f72fcf8..111914f643 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
@@ -24,6 +24,10 @@ tf_class {
argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'<function relu instance>\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
+ name: "eval_dir"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
index 303fd74a64..67e4ee02d0 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
@@ -24,6 +24,10 @@ tf_class {
argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'<function relu instance>\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
+ name: "eval_dir"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
index c97ea7969e..e1289b975e 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
@@ -24,6 +24,10 @@ tf_class {
argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'<function relu instance>\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
+ name: "eval_dir"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
index 4b5b5bf0e3..d030b2f51f 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
@@ -24,6 +24,10 @@ tf_class {
argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'<function relu instance>\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
+ name: "eval_dir"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt
index 42a0d59521..d72b576977 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt
@@ -23,6 +23,10 @@ tf_class {
argspec: "args=[\'self\', \'model_fn\', \'model_dir\', \'config\', \'params\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "eval_dir"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
index 2de52d6c57..cb578759ee 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
@@ -24,6 +24,10 @@ tf_class {
argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
+ name: "eval_dir"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
index e552f33720..fcd01bb663 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
@@ -24,6 +24,10 @@ tf_class {
argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
+ name: "eval_dir"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "evaluate"
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
index 3fc64dae88..acc3fc4c5b 100644
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
@@ -110,7 +110,7 @@ tf_module {
}
member_method {
name: "non_max_suppression"
- argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\'], "
+ argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'None\'], "
}
member_method {
name: "pad_to_bounding_box"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
index cee76bdc1d..1568c3175b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
@@ -155,7 +155,7 @@ tf_class {
}
member_method {
name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
}
member_method {
name: "fit"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
index 02718cb5f9..10ddd5378b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
@@ -160,7 +160,7 @@ tf_class {
}
member_method {
name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
}
member_method {
name: "fit"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt
index ba2d083a75..c6149e8aa7 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt
@@ -450,7 +450,7 @@ tf_module {
}
member_method {
name: "softmax"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], "
}
member_method {
name: "softplus"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt
index 5838d58312..805b1c350e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'monitor\', \'factor\', \'patience\', \'verbose\', \'mode\', \'epsilon\', \'cooldown\', \'min_lr\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0.1\', \'10\', \'0\', \'auto\', \'0.0001\', \'0\', \'0\'], "
+ argspec: "args=[\'self\', \'monitor\', \'factor\', \'patience\', \'verbose\', \'mode\', \'min_delta\', \'cooldown\', \'min_lr\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0.1\', \'10\', \'0\', \'auto\', \'0.0001\', \'0\', \'0\'], "
}
member_method {
name: "in_cooldown"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt
index 3d0acfed1d..1d80559a5e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'root\', \'path\', \'field\', \'headers\'], varargs=None, keywords=None, defaults=[\'http://localhost:9000\', \'/publish/epoch/end/\', \'data\', \'None\'], "
+ argspec: "args=[\'self\', \'root\', \'path\', \'field\', \'headers\', \'send_as_json\'], varargs=None, keywords=None, defaults=[\'http://localhost:9000\', \'/publish/epoch/end/\', \'data\', \'None\', \'False\'], "
}
member_method {
name: "on_batch_begin"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
index 5e5b04c7c6..63123c905c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
@@ -119,7 +119,7 @@ tf_class {
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\', \'initial_state\', \'constants\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
index 82dc878a8c..6be64be6ea 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
@@ -82,7 +82,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
index 54eda8ee21..c789e3fb97 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
index 815e34a48d..e2f97ece6f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
index dd78384005..bbb15950ae 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
@@ -155,7 +155,7 @@ tf_class {
}
member_method {
name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
}
member_method {
name: "fit"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
index 9fcb03f47e..8ba2aa00fb 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
@@ -160,7 +160,7 @@ tf_class {
}
member_method {
name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
}
member_method {
name: "fit"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt
index 5a446c09d0..4d7a1519ce 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt
@@ -46,7 +46,7 @@ tf_module {
}
member_method {
name: "multi_gpu_model"
- argspec: "args=[\'model\', \'gpus\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'model\', \'gpus\', \'cpu_merge\', \'cpu_relocation\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], "
}
member_method {
name: "normalize"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
index efa4419692..fa76e91d2c 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
@@ -92,7 +92,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 0b12bc060e..823a36dad4 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -497,6 +497,10 @@ tf_module {
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
+ name: "strings"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "summary"
mtype: "<type \'module\'>"
}
diff --git a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt
new file mode 100644
index 0000000000..a3fbe95bba
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.strings"
+tf_module {
+ member_method {
+ name: "regex_full_match"
+ argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
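
The new tf.strings namespace starts with regex_full_match, which returns a boolean tensor marking which input strings match the given regular expression in full. A sketch (TF 1.x graph mode):

    import tensorflow as tf

    values = tf.constant(['abc123', 'abc', '123'])
    matches = tf.strings.regex_full_match(values, pattern=r'[a-z]+[0-9]+')

    with tf.Session() as sess:
        print(sess.run(matches))  # [ True False False]
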
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages_remote.sh b/tensorflow/tools/ci_build/install/install_pip_packages_remote.sh
index 0beabcf5ef..721590f4d6 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages_remote.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages_remote.sh
@@ -20,8 +20,8 @@ if [ ! -f /usr/bin/x86_64-linux-gnu-gcc ]; then
ln -s /usr/local/bin/clang /usr/bin/x86_64-linux-gnu-gcc
fi
-pip2 install --upgrade setuptools
-pip3 install --upgrade setuptools
+easy_install -U pip==9.0.3
+easy_install3 -U pip==9.0.3
# The rest of the pip packages will be installed in
# `install_pip_packages.sh`
diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc
index 3b9dd3dd2d..5cae8f8d8f 100644
--- a/tensorflow/tools/graph_transforms/transform_graph.cc
+++ b/tensorflow/tools/graph_transforms/transform_graph.cc
@@ -141,7 +141,7 @@ std::string ExpandPath(const std::string& path_string) {
return path_string;
}
- const char* home = NULL;
+ const char* home = nullptr;
std::string::size_type prefix = path_string.find_first_of('/');
if (path_string.length() == 1 || prefix == 1) {
// The value of $HOME, e.g., ~/foo
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index 3af79ee170..1a83c6e757 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -53,6 +53,7 @@ function main() {
PKG_NAME_FLAG=""
GPU_BUILD=0
NIGHTLY_BUILD=0
+ PROJECT_NAME=""
while true; do
if [[ "$1" == "--nightly_flag" ]]; then
NIGHTLY_BUILD=1
@@ -60,6 +61,12 @@ function main() {
GPU_BUILD=1
elif [[ "$1" == "--gpudirect" ]]; then
PKG_NAME_FLAG="--project_name tensorflow_gpudirect"
+ elif [[ "$1" == "--project_name" ]]; then
+ shift
+ if [[ -z "$1" ]]; then
+ break
+ fi
+ PROJECT_NAME="$1"
fi
shift
@@ -68,7 +75,9 @@ function main() {
fi
done
- if [[ ${NIGHTLY_BUILD} == "1" && ${GPU_BUILD} == "1" ]]; then
+ if [[ -n ${PROJECT_NAME} ]]; then
+ PKG_NAME_FLAG="--project_name ${PROJECT_NAME}"
+ elif [[ ${NIGHTLY_BUILD} == "1" && ${GPU_BUILD} == "1" ]]; then
PKG_NAME_FLAG="--project_name tf_nightly_gpu"
elif [[ ${NIGHTLY_BUILD} == "1" ]]; then
PKG_NAME_FLAG="--project_name tf_nightly"
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 0d10162c30..319878e1b5 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -33,6 +33,21 @@ from setuptools.dist import Distribution
# result for pip.
_VERSION = '1.8.0'
+_SHORT_DESCRIPTION = ('TensorFlow is an open source machine learning framework '
+ 'for everyone.')
+
+_LONG_DESCRIPTION = ('TensorFlow is an open source software library for high '
+ 'performance numerical computation. Its flexible '
+ 'architecture allows easy deployment of computation across'
+ ' a variety of platforms (CPUs, GPUs, TPUs), and from '
+ 'desktops to clusters of servers to mobile and edge '
+ 'devices. Originally developed by researchers and '
+ 'engineers from the Google Brain team within Google\'s AI '
+ 'organization, it comes with strong support for machine '
+ 'learning and deep learning and the flexible numerical '
+ 'computation core is used across many other scientific '
+ 'domains.')
+
REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
'astor >= 0.6.0',
@@ -214,8 +229,8 @@ headers = (list(find_files('*.h', 'tensorflow/core')) +
setup(
name=project_name,
version=_VERSION.replace('-', ''),
- description='TensorFlow helps the tensors flow',
- long_description='',
+ description=_SHORT_DESCRIPTION,
+ long_description=_LONG_DESCRIPTION,
url='https://www.tensorflow.org/',
author='Google Inc.',
author_email='opensource@google.com',
@@ -261,4 +276,5 @@ setup(
'Topic :: Software Development :: Libraries :: Python Modules',
],
license='Apache 2.0',
- keywords='tensorflow tensor machine learning',)
+ keywords='tensorflow tensor machine learning',
+)
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 01d424f20b..ea31df0e06 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -453,11 +453,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7b8a8728fbd27086efbf3c57cf2bb35a557108c9.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/7b8a8728fbd27086efbf3c57cf2bb35a557108c9.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/a915f005cd63fd111bbca510236a5163a7e83576.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/a915f005cd63fd111bbca510236a5163a7e83576.tar.gz",
],
- sha256 = "c620859c3ae5818f316de4837f340b3bba1646f8add0a28e6d4da34ce47e3969",
- strip_prefix = "llvm-7b8a8728fbd27086efbf3c57cf2bb35a557108c9",
+ sha256 = "1c81ec0f843ea2c9369ccfa1c1b20023dc9a999bf075ae192fcb89e23896d929",
+ strip_prefix = "llvm-a915f005cd63fd111bbca510236a5163a7e83576",
build_file = clean_dep("//third_party/llvm:llvm.BUILD"),
)
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index ede7e31897..f3a80d3dd3 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -604,7 +604,7 @@ def _find_cupti_header_dir(repository_ctx, cuda_config):
for relative_path in CUPTI_HEADER_PATHS:
if repository_ctx.path("%s/%scupti.h" % (cuda_toolkit_path, relative_path)).exists:
return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
- auto_configure_fail("Cannot find cupti.h under %s" % cuda_toolkit_path)
+ auto_configure_fail("Cannot find cupti.h under %s" % ", ".join([cuda_toolkit_path + "/" + s for s in CUPTI_HEADER_PATHS]))
def _find_cupti_lib(repository_ctx, cuda_config):
diff --git a/third_party/libxsmm.BUILD b/third_party/libxsmm.BUILD
index 4124f2db63..78ed1f4e16 100644
--- a/third_party/libxsmm.BUILD
+++ b/third_party/libxsmm.BUILD
@@ -38,8 +38,8 @@ genrule(
":libxsmm_interface",
],
visibility = [
- "//tensorflow/core/kernels:__pkg__",
"//third_party/eigen3:__pkg__",
+ "//tensorflow/core/kernels:__pkg__",
],
)
diff --git a/tools/bazel.rc b/tools/bazel.rc
index 1c1e6afb65..03aa52da1f 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -1,8 +1,14 @@
+# By default, we don't distinguish between target and host platforms.
+# When cross compiling, use --config=cross_compile to distinguish them.
+build --distinct_host_configuration=false
+build:cross_compile --distinct_host_configuration=true
+
# Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the
# target CPU to build transient dependencies correctly. See
# https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu
build:android --crosstool_top=//external:android/crosstool
build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
+build:android --config=cross_compile
build:android_arm --config=android
build:android_arm --cpu=armeabi-v7a
build:android_arm --fat_apk_cpu=armeabi-v7a